https://www.acmicpc.net/problem/16398
Minimum Spanning Tree(MST)는 문제를 보통 2가지 유형으로 푸는 듯 하다.
스패닝 트리가 된다는 조건 하에 heap 방식으로 풀이 => Prim 알고리즘인듯..?
Disjoint Set(분리집합) 문제와 유사하게 parent를 두어, 부모가 같은 지 확인하고 union 하는 방식으로 풀이 => Kruskal 알고리즘인듯..?
이번 문제의 경우에는 heap을 사용하였다. 왜냐하면 스패닝 트리가 완성되지 않는다는 말이 없어서, 이 부분에 대해 확인해줄 필요없이 heap 방식으로 사용하고자 했다.
내가 푼 Prim 알고리즘 버전의 코드
import sys
import heapq
input = sys.stdin.readline
n = int(input().strip())
# 방문했나?
visited = [False for i in range(n+1)]
# 처음 heap == cost 0, start node 1
heap = [[0,1]]
edges= []
# cnt는 현재 몇개의 노드를 방문했는지 세고, ans는 답
cnt, ans = 0,0
# 입력
for _ in range(n):
edges.append(list(map(int,input().split())))
# heap에 값이 있을 동안
while heap:
if cnt == n:
break
w, s = heapq.heappop(heap)
# 여태 방문하지 않은 노드만 선별
if not visited[s]:
visited[s] = True
cnt += 1
ans += w
for i in range(n):
if s-1 == i:
continue
heapq.heappush(heap, [edges[s-1][i],i+1])
print(ans)
그리고 친구로부터 받은 Kruskal 알고리즘 버전의 코드
import sys
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
def union(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
n = int(sys.stdin.readline().strip().split())
arr = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(n)]
parent = [i for i in range(n)]
edges = []
for i in range(1, n):
for j in range(i):
edges.append((arr[i][j], i, j))
edges.sort()
answer = 0
for cost, a, b in edges:
if find_parent(parent, a) != find_parent(parent, b):
union(parent, a, b)
answer += cost
print(answer)