https://www.acmicpc.net/problem/15481
시간 2초, 메모리 512MB
input :
output :
조건 :
MST를 만드는데 특정 간선을 추가해야 하는 문제이다.
최소 스패닝 트리의 경우에는 크루스칼 + union-find를 통해 구현한다.
특정 간선을 포함하기 위해선 빠르게 찾는 방법으로 LCA를 사용할 수 있다.
특정 노드 (a, b)가 있을 때 이 둘이 연결되어 있는 모든 간선 중 가장 큰 값을 찾는 것으로 그 중 제일 작은 MST를 만들 수 있다.
그러면 가장 중요한 문제로 어떻게 가장 큰 값을 찾냐? 를 해결해야 한다.
sparse table을 사용하는 LCA를 통해 이를 해결할 수 있다.
특정 노드만을 저장하는 것이 아닌, [노드 + 지금까지 간선들 중 가장 큰 값]을 저장하게 할 수 있다.
즉, sparse table을 초기화 할 때 해당 log 범위 위의 노드를 찾을 때 간선도 같이 체킹을 하도록 해라.
for log in range(1, 21):
for node in range(1, n + 1):
next_node, next_weight = l_parent[node][log - 1]
l_parent[node][log] = [l_parent[next_node][log - 1][0], max(next_weight, l_parent[next_node][log - 1][1])]
l_parent[현재 노드][2^log에 위치한 노드] = [해당 노드의 넘버, 여기까지 최대값의 간선]
이 값을 통해 LCA의 리턴값이 노드를 주는 것이 아닌 최대값을 리턴하도록 하면 됨.
import sys
from heapq import heappop, heappush
def find(node):
if node == parent[node]:
return node
parent[node] = find(parent[node])
return parent[node]
def union(a, b):
parent_a = find(a)
parent_b = find(b)
if parent_a > parent_b:
parent[parent_a] = parent_b
else:
parent[parent_b] = parent_a
def dfs(node, depth):
visit[node] = 1
level[node] = depth
for next_node, weight in graph[node]:
if visit[next_node]:
continue
l_parent[next_node][0] = [node, weight]
dfs(next_node, depth + 1)
def set_parent():
dfs(1, 0)
for log in range(1, 21):
for node in range(1, n + 1):
next_node, next_weight = l_parent[node][log - 1]
l_parent[node][log] = [l_parent[next_node][log - 1][0], max(next_weight, l_parent[next_node][log - 1][1])]
def lca(high, low):
ret = 0
if level[high] > level[low]:
high, low = low, high
for log in range(20, -1, -1):
if level[low] - level[high] >= (1 << log):
ret = max(ret, l_parent[low][log][1])
low = l_parent[low][log][0]
if high == low:
return ret
for log in range(20, -1, -1):
if l_parent[low][log][0] != l_parent[high][log][0]:
ret = max(ret, max(l_parent[low][log][1], l_parent[high][log][1]))
low = l_parent[low][log][0]
high = l_parent[high][log][0]
return max(ret, l_parent[low][0][1], l_parent[high][0][1])
n, m = map(int, sys.stdin.readline().split())
edge, data, parent, graph = [], [], [i for i in range(n + 1)], [[] for i in range(n + 1)]
l_parent, visit, level = [[[0, 0] for _ in range(21)] for _ in range(n + 1)], [0] * (n + 1), [0] * (n + 1)
mst= 0
for _ in range(m):
u, v, w = list(map(int, sys.stdin.readline().split()))
data.append((u, v, w))
heappush(edge, (w, u, v))
while edge:
w, u, v = heappop(edge)
if find(u) != find(v):
union(u, v)
graph[u].append((v, w))
graph[v].append((u, w))
mst += w
set_parent()
for u, v, w in data:
cost = lca(u, v)
print(mst - cost + w)
다른 분의 풀이를 통해 "맞힌 사람" 카테고리를 확인해서 참고하려 했다.
근데 다른 것보다, lca 마지막 리턴 값에서 max()를 제대로 쓰지 못한 부분이 주요했다.
마지막에 max(ret, low의 최댓값, high의 최댓값)중 가장 큰 값을 가져가도록 하자.