[백준] #1626 두 번째로 작은 스패닝 트리

주재완·2025년 6월 29일
0

[C++] 백준

목록 보기
1/1
post-thumbnail

C++로 언어 변경한지는 좀 지났는데, C++로 작성하는 첫 포스팅이 되었네요

https://www.acmicpc.net/problem/1626
문제는 아주 간단합니다. MST 가 아닌 SMST(Second Minimum Spanning Tree)를 구하면 됩니다.

개요

우선 희소 배열에 대한 사전 지식이 있어야 하고, #15481 그래프와 MST를 미리 푸는 것이 좋습니다.

여기서는 희소 배열에 대한 것은 안다고 생각하고, 15481번 과 1626번 순으로 풀이하겠습니다.

15481 - 그래프와 MST(P1)

각각의 간선에 대해서 각 간선을 포함하는 MST를 구할 때 그 가중치 합들을 각각 구하면 되는 문제입니다. 정점의 수는 N, 간선의 수는 M이라 가정합니다.

우선 시간 복잡도를 고려하지 않은 풀이를 생각하면 다음과 같습니다.

  • 간선 정보를 모두 저장한다.
  • 각 간선에 대해서 해당하는 간선이 연결되어 있다고 생각하고 MST를 구한다
  • 가중치 합을 구한다.

MST를 크루스칼 알고리즘으로 구한다고 하면 기본적으로 O(MlogM)이 나오고, 각 간선에 대해서 모두 따지면 O(M * MlogM) 이라는 TLE 받기 좋은 시간 복잡도가 나오게 됩니다. 그래서 조금은 다른 방법을 생각해봅니다.

여기서 크루스칼 알고리즘을 M번 돌리는게 문제라는 걸 알 수 있습니다. 그래서 이를 최소한으로, 가능하면 딱 한번만 돌려서 판단해보면 좋을 것 같습니다. 여기서 파생해서 이러한 아이디어를 생각할 수 있습니다.

  • 일단 MST를 구한다.
  • 각 간선에 대해서
    • 해당 간선이 MST에 포함되면 추가 연산이 불필요하다.
    • 하지만, MST에 포함되지 않으면 추가 연산이 필요하다.

추가 연산이 MST 다시 구하는거면 너무 무겁습니다. 하지만 이렇게 생각해볼 수 있습니다. 결국 해당 간선이 MST에 해당하지 않으면 해당하는 정점이 해당 간선으로는 연결이 되지 않았다는 것입니다. 다른 경로가 MST에 있다는 것입니다. 각 정점을 u, v라 하겠습니다.

여기서 MST는 기본적으로 트리이기 때문에 u, v 사이 경로가 유일함이 보장됩니다. 그렇기에 MST의 u, v 사이 경로 중 최대가 되는 값을 우리가 원하는 간선으로 대치 해주면 원하는 간선을 포함한 MST가 되는 것입니다. 유일한 경로을 끊어주고 새로운 유일한 경로를 만들어주는 느낌입니다.

즉, MST를 그리고 각 정점 사이 간선의 최대값 을 저장해주면 됩니다. 그리고 이를 저장할 때 유용한 자료구조가 바로 희소 배열이고, 최소 공통 조상(LCA) 알고리즘을 통해서 저장, 조회를 진행하면 됩니다.

코드에서는 특별히 해당 간선이 MST인지 아닌지를 별도로 구분하진 않았습니다. 어차피 MST에 포함된다면 본인을 끊었다 다시 본인으로 연결하는 꼴이기 때문입니다.

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

typedef long long ll;
typedef pair<int, ll> pil;

struct Edge {
    int u, v; ll w;
    Edge(int _u, int _v, ll _w) : u(_u), v(_v), w(_w) {}
};

const int MAX = 200'001;
int n, m, p[MAX], depth[MAX], pp[20][MAX];
ll wmax[20][MAX];
vector<pil> g[MAX], mst[MAX];
vector<Edge> edges, sorted_edges;

int find_p(int x) { return p[x] == x ? x : p[x] = find_p(p[x]); }
void union_p(int x, int y) { (x > y ? p[x] : p[y]) = (x > y ? y : x); }
bool comp(Edge e1, Edge e2) { return e1.w < e2.w; }

void dfs(int cur, int d) {
    depth[cur] = d;
    for(auto nxt : mst[cur]) {
        if(depth[nxt.first] == 0) {
            pp[0][nxt.first] = cur;
            wmax[0][nxt.first] = nxt.second;
            dfs(nxt.first, d + 1);
        }
    }
}

ll lca(int u, int v) {
    ll res = 0;
    if(depth[u] < depth[v]) swap(u, v);
    for(int i = 19; i >= 0; --i) {
        if(depth[pp[i][u]] >= depth[v]) {
            res = max(res, wmax[i][u]);
            u = pp[i][u];
        }
    }
    if(u == v) return res;
    for(int i = 19; i >= 0; --i) {
        if(pp[i][u] != pp[i][v]) {
            res = max(res, max(wmax[i][u], wmax[i][v]));
            u = pp[i][u];
            v = pp[i][v];
        }
    }
    res = max(res, max(wmax[0][u], wmax[0][v]));
    return res;
}

int main() {
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);

    cin >> n >> m;
    int u, v; ll w;
    for(int i = 0; i < m; ++i) {
        cin >> u >> v >> w;
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
        Edge e(u, v, w);
        edges.emplace_back(e);
        sorted_edges.emplace_back(e);
    }

    ll mst_val = 0;
    int cnt = 0;
    sort(sorted_edges.begin(), sorted_edges.end(), comp);
    for(int i = 1; i <= n; ++i) p[i] = i;
    for(auto& e : sorted_edges) {
        int u = find_p(e.u), v = find_p(e.v);
        if (u != v) {
            union_p(u, v);
            mst_val += e.w;
            mst[e.u].emplace_back(e.v, e.w);
            mst[e.v].emplace_back(e.u, e.w);
            if (++cnt == n - 1) break;
        }
    }

    dfs(1, 1);
    for(int i = 1; i < 20; ++i) {
        for(int j = 1; j <= n; ++j) {
            pp[i][j] = pp[i - 1][pp[i - 1][j]];
            wmax[i][j] = max(wmax[i - 1][j], wmax[i - 1][pp[i - 1][j]]);
        }
    }
    
    for (auto& e : edges) cout << mst_val - lca(e.u, e.v) + e.w << '\n';

    return 0;
}

1626 - 두 번째로 작은 스패닝 트리(D4)

위 문제를 풀었다면, 이 문제 아이디어도 어렵지 않게 생각할 수 있습니다. 다만 몇가지 실수할 여지가 있습니다.

이 문제는 SMST(Second Minimum Spanning Tree)를 구하면 되는 것이고 핵심은 MST보다 크다 에 있습니다. 그렇기에 MST를 만들 수 없거나 아니면 MST 밖에 못만드는 상황에서는 SMST를 만들 수 없습니다.

처음 생각(오답)

위 그래프와 MST 문제처럼 각 간선을 택했을 때 최댓값을 변경하는 방식으로 진행했습니다. 그래서 나온 코드는 아래와 같습니다.

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;
typedef pair<int, int> pii;
typedef long long ll;

struct Edge {
    int u, v, w;
    Edge(int _u, int _v, int _w) : u(_u), v(_v), w(_w) {}
};

const int MAX = 50'001;

vector<pii> g[MAX], mst[MAX];
vector<Edge> edges;

int n, m;
int par[MAX], dep[MAX], pp[17][MAX], wmax[17][MAX];

int find_p(int x) { return (par[x] == x ? x : par[x] = find_p(par[x])); }
void union_p(int x, int y) { (x > y ? par[x] : par[y]) = (x > y ? y : x); }
bool comp(Edge e1, Edge e2) { return e1.w < e2.w; }

int kruskal() {
    int res = 0, cnt = 0;
    sort(edges.begin(), edges.end(), comp);
    for(int i = 1; i <= n; ++i) par[i] = i;
    for(auto e : edges) {
        int pu = find_p(e.u);
        int pv = find_p(e.v);
        if(pu != pv) {
            union_p(pu, pv);
            res += e.w;
            mst[e.u].emplace_back(e.v, e.w);
            mst[e.v].emplace_back(e.u, e.w);
            if(++cnt == n - 1) return res;
        }
    }
    return -1;
}

void dfs(int cur, int d) {
    dep[cur] = d;
    for(auto nxt : mst[cur]) {
        if(dep[nxt.first] == 0) {
            pp[0][nxt.first] = cur;
            wmax[0][nxt.first] = nxt.second;
            dfs(nxt.first, d + 1);
        }
    }
}

int lca(int u, int v) {
    int res = 0;
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 16; i >= 0; --i) {
        if(dep[pp[i][u]] >= dep[v]) {
            res = max(res, wmax[i][u]);
            u = pp[i][u];
        }
    }
    if(u == v) return res;
    for(int i = 16; i >= 0; --i) {
        if(pp[i][u] != pp[i][v]) {
            res = max(res, max(wmax[i][u], wmax[i][v]));
            u = pp[i][u];
            v = pp[i][v];
        }
    }
    res = max(res, max(wmax[0][u], wmax[0][v]));
    return res;
}

int main() {
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin >> n >> m;
    int u, v, w;
    for(int i = 0; i < m; ++i) {
        cin >> u >> v >> w;
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
        Edge e(u, v, w);
        edges.emplace_back(e);
    }
    
    int mst_val = kruskal();
    if(mst_val == -1) {
        cout << mst_val << '\n';
        return 0;
    }

    dfs(1, 1);
    for(int i = 1; i < 17; ++i) {
        for(int j = 1; j <= n; ++j) {
            pp[i][j] = pp[i - 1][pp[i - 1][j]];
            wmax[i][j] = max(wmax[i - 1][j], wmax[i - 1][pp[i - 1][j]]);
        }
    }
    
    ll smst = 1e18;
    for(auto e : edges) {
        int l = lca(e.u, e.v);
        if(l < e.w) smst = min(smst, ll(mst_val - l + e.w));
    }

    cout << (smst == 1e18 ? -1 : smst) << '\n';
    return 0;
}

최종(정답)

하지만, 위 코드로 제출할 경우 반례가 있습니다. 바로 최댓값에 해당하는 간선의 가중치와 포함하려는 간선의 가중치가 동일한 경우 입니다. 이런 경우는 SMST가 아닌 그냥 새로운 MST를 구하는 꼴입니다.

이를 해결하기 위해서는 바로 최댓값 뿐만 아니라 두번째로 큰 간선도 같이 저장해줍니다. 각각 wmaxswmax로 저장을 해줍니다.

두번째로 큰 값 구하는게 조금 귀찮은데, 여러 방법이 있지만 다음과 같이 일일이 순회하면서 구현했습니다.

// {최대, 두번째 최대} 구하기
// vector<int> nxt - 기존 최대, 기존 두번째 최대, 새로운 값들 등을 넣는 인자
pii calc(vector<int> nxt) {
    int mm = -1, sm = -1;
    for(int x : nxt) {
        if(x == -1) continue;
        if(mm < x) {
            sm = mm;
            mm = x;
        } else if(sm < x && x < mm) {
            sm = x;
        }
    }
    return { mm, sm };
}

그래서 다음과 같이 희소 배열 초기화할 때나

    dfs(1, 1);
    for(int i = 1; i < 17; ++i) {
        for(int j = 1; j <= n; ++j) {
            pp[i][j] = pp[i - 1][pp[i - 1][j]];
            pii res = calc({ wmax[i - 1][j], swmax[i - 1][j], wmax[i - 1][pp[i - 1][j]], swmax[i - 1][pp[i - 1][j]] });
            wmax[i][j] = res.first;
            swmax[i][j] = res.second;
        }
    }

아니면 lca 구할 때

pii lca(int u, int v) {
    int mm = -1, sm = -1;
    pii res;
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 16; i >= 0; --i) {
        if(dep[pp[i][u]] >= dep[v]) {
            res = calc({ mm, sm, wmax[i][u], swmax[i][u] });
            mm = res.first;
            sm = res.second;
            u = pp[i][u];
        }
    }
    if(u == v) return { mm, sm };
    for(int i = 16; i >= 0; --i) {
        if(pp[i][u] != pp[i][v]) {
            res = calc({ mm, sm, wmax[i][u], swmax[i][u], wmax[i][v], swmax[i][v] });
            mm = res.first;
            sm = res.second;
            u = pp[i][u];
            v = pp[i][v];
        }
    }
    res = calc({ mm, sm, wmax[0][u], swmax[0][u], wmax[0][v], swmax[0][v] });
    mm = res.first;
    sm = res.second;
    return { mm, sm };
}

이 때 중간중간 넣어주면서 구현해주면 됩니다.

최종 코드

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;
typedef pair<int, int> pii;
typedef long long ll;

struct Edge {
    int u, v, w;
    Edge(int _u, int _v, int _w) : u(_u), v(_v), w(_w) {}
};

const int MAX = 50001;

bool check[200000];
vector<pii> g[MAX], mst[MAX];
vector<Edge> edges;

int n, m;
int par[MAX], dep[MAX], pp[17][MAX], wmax[17][MAX], swmax[17][MAX];

int find_p(int x) { return (par[x] == x ? x : par[x] = find_p(par[x])); }
void union_p(int x, int y) { (x > y ? par[x] : par[y]) = (x > y ? y : x); }
bool comp(Edge e1, Edge e2) { return e1.w < e2.w; }

int kruskal() {
    int res = 0, cnt = 0;
    sort(edges.begin(), edges.end(), comp);
    for(int i = 1; i <= n; ++i) par[i] = i;
    for(int i = 0; i < m; ++i) {
        Edge e = edges[i];
        int pu = find_p(e.u);
        int pv = find_p(e.v);
        if(pu != pv) {
            union_p(pu, pv);
            res += e.w;
            mst[e.u].emplace_back(e.v, e.w);
            mst[e.v].emplace_back(e.u, e.w);
            check[i] = true;
            if(++cnt == n - 1) return res;
        }
    }
    return -1;
}

void dfs(int cur, int d) {
    dep[cur] = d;
    for(auto nxt : mst[cur]) {
        if(dep[nxt.first] == 0) {
            pp[0][nxt.first] = cur;
            wmax[0][nxt.first] = nxt.second;
            swmax[0][nxt.first] = -1;
            dfs(nxt.first, d + 1);
        }
    }
}

pii calc(vector<int> nxt) {
    int mm = -1, sm = -1;
    for(int x : nxt) {
        if(x == -1) continue;
        if(mm < x) {
            sm = mm;
            mm = x;
        } else if(sm < x && x < mm) {
            sm = x;
        }
    }
    return { mm, sm };
}

pii lca(int u, int v) {
    int mm = -1, sm = -1;
    pii res;
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 16; i >= 0; --i) {
        if(dep[pp[i][u]] >= dep[v]) {
            res = calc({ mm, sm, wmax[i][u], swmax[i][u] });
            mm = res.first;
            sm = res.second;
            u = pp[i][u];
        }
    }
    if(u == v) return { mm, sm };
    for(int i = 16; i >= 0; --i) {
        if(pp[i][u] != pp[i][v]) {
            res = calc({ mm, sm, wmax[i][u], swmax[i][u], wmax[i][v], swmax[i][v] });
            mm = res.first;
            sm = res.second;
            u = pp[i][u];
            v = pp[i][v];
        }
    }
    res = calc({ mm, sm, wmax[0][u], swmax[0][u], wmax[0][v], swmax[0][v] });
    mm = res.first;
    sm = res.second;
    return { mm, sm };
}

int main() {
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin >> n >> m;
    int u, v, w;
    for(int i = 0; i < m; ++i) {
        cin >> u >> v >> w;
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
        Edge e(u, v, w);
        edges.emplace_back(e);
    }
    
    int mst_val = kruskal();
    if(mst_val == -1) {
        cout << mst_val << '\n';
        return 0;
    }

    dfs(1, 1);
    for(int i = 1; i < 17; ++i) {
        for(int j = 1; j <= n; ++j) {
            pp[i][j] = pp[i - 1][pp[i - 1][j]];
            pii res = calc({ wmax[i - 1][j], swmax[i - 1][j], wmax[i - 1][pp[i - 1][j]], swmax[i - 1][pp[i - 1][j]] });
            wmax[i][j] = res.first;
            swmax[i][j] = res.second;
        }
    }

    ll smst = 1e18;
    int mm, sm;
    for(int i = 0; i < m; ++i) {
        if(check[i]) continue;
        Edge e = edges[i];
        pii res = lca(e.u, e.v);
        mm = res.first, sm = res.second;
        if(e.w > mm) smst = min(smst, ll(mst_val - mm + e.w));
        else if(e.w > sm && sm != -1) smst = min(smst, ll(mst_val - sm + e.w));
    }

    cout << (smst == 1e18 ? -1 : smst) << '\n';
    return 0;
}
profile
안녕하세요! 언제나 탐구하고 공부하는 개발자, 주재완입니다.

0개의 댓글