BOJ 15481 그래프와 MST

LONGNEW·2022년 1월 28일
0

BOJ

목록 보기
308/333

https://www.acmicpc.net/problem/15481
시간 2초, 메모리 512MB

input :

  • N M (2 ≤ N ≤ 200,000, N-1 ≤ M ≤ 200,000)
  • u v w (1 ≤ u, v ≤ n, u ≠ v, 1 ≤ w ≤ 10^9)

output :

  • 그 간선을 포함하는 최소 스패닝 트리의 가중치 합을 출력

조건 :

  • 그래프 G는 루프가 없고, 두 정점을 연결하는 간선은 최대 1개

MST를 만드는데 특정 간선을 추가해야 하는 문제이다.

다음 풀이.

  1. 최소 스패닝 트리 제작
  2. 특정 간선을 포함

최소 스패닝 트리의 경우에는 크루스칼 + union-find를 통해 구현한다.
특정 간선을 포함하기 위해선 빠르게 찾는 방법으로 LCA를 사용할 수 있다.
특정 노드 (a, b)가 있을 때 이 둘이 연결되어 있는 모든 간선 중 가장 큰 값을 찾는 것으로 그 중 제일 작은 MST를 만들 수 있다.
그러면 가장 중요한 문제로 어떻게 가장 큰 값을 찾냐? 를 해결해야 한다.

sparse table을 사용하는 LCA를 통해 이를 해결할 수 있다.
특정 노드만을 저장하는 것이 아닌, [노드 + 지금까지 간선들 중 가장 큰 값]을 저장하게 할 수 있다.

즉, sparse table을 초기화 할 때 해당 log 범위 위의 노드를 찾을 때 간선도 같이 체킹을 하도록 해라.

    for log in range(1, 21):
        for node in range(1, n + 1):

            next_node, next_weight = l_parent[node][log - 1]
            l_parent[node][log] = [l_parent[next_node][log - 1][0], max(next_weight, l_parent[next_node][log - 1][1])]

l_parent[현재 노드][2^log에 위치한 노드] = [해당 노드의 넘버, 여기까지 최대값의 간선]

이 값을 통해 LCA의 리턴값이 노드를 주는 것이 아닌 최대값을 리턴하도록 하면 됨.

import sys
from heapq import heappop, heappush

def find(node):
    if node == parent[node]:
        return node
    parent[node] = find(parent[node])
    return parent[node]

def union(a, b):
    parent_a = find(a)
    parent_b = find(b)

    if parent_a > parent_b:
        parent[parent_a] = parent_b
    else:
        parent[parent_b] = parent_a

def dfs(node, depth):
    visit[node] = 1
    level[node] = depth

    for next_node, weight in graph[node]:
        if visit[next_node]:
            continue

        l_parent[next_node][0] = [node, weight]
        dfs(next_node, depth + 1)

def set_parent():
    dfs(1, 0)

    for log in range(1, 21):
        for node in range(1, n + 1):

            next_node, next_weight = l_parent[node][log - 1]
            l_parent[node][log] = [l_parent[next_node][log - 1][0], max(next_weight, l_parent[next_node][log - 1][1])]

def lca(high, low):
    ret = 0
    if level[high] > level[low]:
        high, low = low, high

    for log in range(20, -1, -1):
        if level[low] - level[high] >= (1 << log):
            ret = max(ret, l_parent[low][log][1])
            low = l_parent[low][log][0]

    if high == low:
        return ret

    for log in range(20, -1, -1):
        if l_parent[low][log][0] != l_parent[high][log][0]:
            ret = max(ret, max(l_parent[low][log][1], l_parent[high][log][1]))
            low = l_parent[low][log][0]
            high = l_parent[high][log][0]

    return max(ret, l_parent[low][0][1], l_parent[high][0][1])


n, m = map(int, sys.stdin.readline().split())
edge, data, parent, graph = [], [], [i for i in range(n + 1)], [[] for i in range(n + 1)]
l_parent, visit, level = [[[0, 0] for _ in range(21)] for _ in range(n + 1)], [0] * (n + 1), [0] * (n + 1)
mst= 0

for _ in range(m):
    u, v, w = list(map(int, sys.stdin.readline().split()))

    data.append((u, v, w))
    heappush(edge, (w, u, v))

while edge:
    w, u, v = heappop(edge)

    if find(u) != find(v):
        union(u, v)
        graph[u].append((v, w))
        graph[v].append((u, w))

        mst += w

set_parent()
for u, v, w in data:
    cost = lca(u, v)
    print(mst - cost + w)

다른 분의 풀이를 통해 "맞힌 사람" 카테고리를 확인해서 참고하려 했다.
근데 다른 것보다, lca 마지막 리턴 값에서 max()를 제대로 쓰지 못한 부분이 주요했다.
마지막에 max(ret, low의 최댓값, high의 최댓값)중 가장 큰 값을 가져가도록 하자.

0개의 댓글