안녕하세요. 오늘은 두번째로 작은 스패닝 트리를 만들 거예요.

문제

https://www.acmicpc.net/problem/1626

아이디어

일단 MST(최소 스패닝 트리)를 만듭니다.
그리고 아래 내용을 해줍니다.

두 정점 a와 b를 잇는 간선을 추가한다. 그럼과 동시에 MST상에서 a와 b를 잇는 간선중 가장 가중치가 큰 간선을 제거한다. 이때 추가하는 간선과 제거하는 간선의 가중치의 차가 최소일수록 좋다.

말로는 쉽습니다. 하지만 디테일한 부분까지 들어가면 굉장히 골치아파집니다.
일단 MST를 만드는것은 크루스칼 알고리즘으로 커버가 가능합니다. 하지만 a와 b를 잇는 간선중 가장 가중치가 큰 간선은 어떻게 찾을까요? 바로 LCA입니다. 이를 이용해서 해결할 수 있습니다.

만약 두 값이 같으면 어떨까요? 그러면 문제의 조건에 의해서 두번째로 작은 최소 스패닝 트리가 되지 못합니다. 따라서 어떤 간선을 제거하면 그 두 정점을 잇는 간선중 가장 가중치가 큰 간선과 두번째로 가중치가 큰 간선을 가져와서 비교를 해주면 됩니다.

그러면 -1을 출력하는 경우는 어떤 경우일까요?
바로 위에서 다룬 "두 값이 같으면"에 조건에서 한발 더 가서 "모든 경로의 값이 같으면"입니다. 모든 경로, 모든 값이 같으면 이렇게 됩니다. 또한 이미 트리였어도 이런 현상이 발생합니다. 하지만 여기서 주의할 점이 있습니다. 바로 그래프가 제대로 만들어지지 않을수도 있다는 것입니다. 즉, 모든 정점이 연결되어있지 않을 수 있습니다. 그래서 크루스칼 알고리즘을 돌린 후에 꼼꼼히 체크를 해주어야합니다.

소스코드

#include <iostream>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;

ll parent[50505] = { 0 };
void UF_init()
{
    for (ll i = 1; i <= 50000; i++)
        parent[i] = i;
}
ll find(ll x)
{
    if (x == parent[x]) return x;
    return parent[x] = find(parent[x]);
}
void Union(ll x, ll y)
{
    parent[find(x)] = find(y);
}
bool same(ll x, ll y)
{
    return find(x) == find(y);
}

ll min_value;
bool cmp(pair <pair <ll, ll>, ll> A, pair <pair <ll, ll>, ll> B)
{
    return A.second < B.second;
}
vector <pair <pair <ll, ll>, ll> > Kruskal(vector <pair <pair <ll, ll>, ll> > v)
{
    vector <pair <pair <ll, ll>, ll> > Ans;

    UF_init();
    sort(v.begin(), v.end(), cmp);
    ll N = v.size();
    for (ll i = 0; i < N; i++)
    {
        if (same(v[i].first.first, v[i].first.second)) continue;
        Ans.push_back(v[i]);
        min_value += v[i].second;
        Union(v[i].first.first, v[i].first.second);
    }
    return Ans;
}

pair <ll, ll> Combine(ll a, ll b, ll c, ll d)
{
    vector <ll> v;
    v.push_back(a);
    v.push_back(b);
    v.push_back(c);
    v.push_back(d);

    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    sort(v.begin(), v.end(), greater<>());

    pair <ll, ll> res;
    if (v.size() == 1) res = { v[0],v[0] };
    else if (v.size() == 0) res = { -1,-1 };
    else res = { v[0],v[1] };

    return res;
}

vector <pair <ll, ll> > graph[50505];
ll LCA_dpt[50505] = { 0 }, LCA_parent[50505][22] = { 0 }, LCA_Max[50505][22] = { 0 }, LCA_SecMax[50505][22] = { 0 };
void DFS(ll node, ll up)
{
    LCA_dpt[node] = LCA_dpt[up] + 1;
    LCA_parent[node][0] = up;
    for (ll i = 1; i <= 20; i++)
    {
        LCA_parent[node][i] = LCA_parent[LCA_parent[node][i - 1]][i - 1];
        pair <ll, ll> temp = Combine(LCA_Max[node][i - 1], LCA_SecMax[node][i - 1], LCA_Max[LCA_parent[node][i - 1]][i - 1], LCA_SecMax[LCA_parent[node][i - 1]][i - 1]);
        LCA_Max[node][i] = temp.first;
        LCA_SecMax[node][i] = temp.second;
    }

    for (auto next : graph[node])
    {
        if (next.first != up)
        {
            LCA_Max[next.first][0] = LCA_SecMax[next.first][0] = next.second;
            DFS(next.first, node);
        }
    }
}
ll LCA(ll x, ll y)
{
    if (LCA_dpt[x] < LCA_dpt[y]) swap(x, y);
    for (ll i = 20; i >= 0; i--)
    {
        if (LCA_dpt[LCA_parent[x][i]] >= LCA_dpt[y])
            x = LCA_parent[x][i];
    }
    if (x == y) return x;
    for (ll i = 20; i >= 0; i--)
    {
        if (LCA_parent[x][i] != LCA_parent[y][i])
        {
            x = LCA_parent[x][i];
            y = LCA_parent[y][i];
        }
    }
    return LCA_parent[x][0];
}
pair <ll, ll> EvaluateMax(ll up, ll down)
{
    pair <ll, ll> Ans = { -1,-1 };
    for (ll i = 20; i >= 0; i--)
    {
        if (LCA_dpt[LCA_parent[down][i]] >= LCA_dpt[up])
        {
            Ans = Combine(Ans.first, Ans.second, LCA_Max[down][i], LCA_SecMax[down][i]);
            down = LCA_parent[down][i];
        }
    }
    return Ans;
}

int main()
{
    ios_base::sync_with_stdio(false); cin.tie(NULL);
    ll N, M, i, a, b, c;
    vector <pair <pair <ll, ll>, ll> > v;

    cin >> N >> M;
    for (i = 1; i <= M; i++)
    {
        cin >> a >> b >> c;
        v.push_back({ {a,b},c });
    }

    vector <pair <pair <ll, ll>, ll> > v2 = Kruskal(v);
    if (v2.size() != N - 1)
    {
        cout << -1;
        return 0;
    }
    for (i = 0; i < N - 1; i++)
    {
        graph[v2[i].first.first].push_back({ v2[i].first.second,v2[i].second });
        graph[v2[i].first.second].push_back({ v2[i].first.first,v2[i].second });
    }

    DFS(1, 0);
    ll Ans = 2e11;
    for (i = 0; i < M; i++)
    {
        ll lca = LCA(v[i].first.first, v[i].first.second);
        pair <ll, ll> first = EvaluateMax(lca, v[i].first.first), second = EvaluateMax(lca, v[i].first.second);
        pair <ll, ll> val = Combine(first.first, first.second, second.first, second.second);

        if (val.first >= 0 && val.first != v[i].second) Ans = min(Ans, min_value - val.first + v[i].second);
        else if (val.second >= 0 && val.second != v[i].second) Ans = min(Ans, min_value - val.second + v[i].second);
    }

    if (Ans == 2e11) cout << -1;
    else cout << Ans;
}


감사합니다.

0개의 댓글