[BOJ 17398] - 통신망 분할 (분리 집합, 오프라인 쿼리, C++, Python)

보양쿠·2023년 9월 20일
0

BOJ

목록 보기
195/260
post-custom-banner

BOJ 17398 - 통신망 분할 링크
(2023.09.20 기준 P5)

문제

N개의 통신 탑과 통신 탑을 잇는 연결 M개가 주어진다.

연결을 끊을 때, 통신망이 두 개로 나뉘면 나뉜 통신망에 속한 통신 탑들의 개수의 곱이 비용으로 들고, 나뉘지 않으면 비용은 0이다.

Q개의 끊을 연결의 비용을 차례대로 출력

알고리즘

분리 집합 및 쿼리를 오프라인으로 처리

풀이

분리 집합은 두 집합을 연결하기 위해 쓰이지만, 이 문제는 반대로 연결을 끊는 문제이다.
우리는 분리 집합이라는 연결하는 기술을 갖고 있고, 이 문제는 연결의 반대인 분할을 필요로 하며, 연결은 분할의 반댓말이다.
즉, 연결 상태에서 하나씩 끊는 것을 반대로 하면, 분할 상태에서 하나씩 연결하는 것이다.

그러므로, 주어지는 쿼리를 반대로 처리하면 된다.
쿼리로 주어지지 않는 연결들은 따로 체크해서, 쿼리를 반대로 처리하기 전에 먼저 주어지지 않았던 연결들을 추가하면 된다.

비용은 끊고 나서의 두 집합의 크기의 곱이었으니, 반대로 처리할 때에는 union하기 전 두 집합의 크기의 곱이 비용이 된다고 생각하면 된다.

코드

  • C++
#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;

typedef long long ll;
typedef pair<int, int> pii;

const int MAXN = 1e5 + 1;

int pa[MAXN], sz[MAXN];

// union-find
int find(int u){
    if (pa[u] != u) pa[u] = find(pa[u]);
    return pa[u];
}

void merge(int u, int v){
    u = find(u); v = find(v);
    if (u < v){
        pa[v] = u; sz[u] += sz[v];
    }
    else if (v < u){
        pa[u] = v; sz[v] += sz[u];
    }
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int N, M, Q; cin >> N >> M >> Q;
    pii edges[M]; for (int i = 0; i < M; i++) cin >> edges[i].x >> edges[i].y;

    // 각 간선이 쿼리로 주어짐을 체크하면서 쿼리 입력 받기
    bool on[M]; fill(on, on + M, false);
    vector<int> queries;
    for (int A; Q; Q--){
        cin >> A;
        on[--A] = true;
        queries.push_back(A);
    }

    // 쿼리로 주어지지 않은 간선들을 먼저 union
    iota(pa, pa + N + 1, 0);
    fill(sz, sz + N + 1, 1);
    for (int i = 0; i < M; i++) if (!on[i]) merge(edges[i].x, edges[i].y);

    // 쿼리가 주어진 역방향으로 union하면서 각 집합의 크기를 곱해 더하자.
    ll result = 0;
    while (!queries.empty()){
        int i = queries.back(); queries.pop_back();
        int u = find(edges[i].x), v = find(edges[i].y);
        if (u == v) continue; // 같은 집합이면 continue
        result += sz[u] * sz[v];
        merge(u, v);
    }

    cout << result;
}
  • Python
import sys; input = sys.stdin.readline

# union-find
def find(u):
    if pa[u] != u:
        pa[u] = find(pa[u])
    return pa[u]

def union(u, v):
    u = find(u)
    v = find(v)
    if u < v:
        pa[v] = u
        sz[u] += sz[v]
    elif v < u:
        pa[u] = v
        sz[v] += sz[u]

N, M, Q = map(int, input().split())
edges = [tuple(map(int, input().split())) for _ in range(M)]

# 각 간선이 쿼리로 주어짐을 체크하면서 쿼리 입력 받기
on = [False] * M
queries = []
for _ in range(Q):
    A = int(input()) - 1
    on[A] = True
    queries.append(A)

# 쿼리로 주어지지 않은 간선들을 먼저 union
pa = [i for i in range(N + 1)]
sz = [1] * (N + 1)
for i in range(M):
    if not on[i]:
        union(*edges[i])

# 쿼리가 주어진 역방향으로 union하면서 각 집합의 크기를 곱해 더하자.
result = 0
while queries:
    u, v = map(lambda x: find(x), edges[queries.pop()])
    if u == v: # 같은 집합이면 continue
        continue
    result += sz[u] * sz[v]
    union(u, v)

print(result)
profile
GNU 16 statistics & computer science
post-custom-banner

0개의 댓글