처음에는 다익스트라로 풀려고했다가 실패했다. 다익스트라는 한 정점에서 다른 모든 정점들에 대해 최소 비용을 알려주는 것이다. 여기서는 전체 간선이 최소화되어야한다.
import heapq as hq
def solution(n, costs):
answer = 0
from_to = [[] for _ in range(n)]
visited = [False] * n
priority = []
for a, b, cost in costs:
from_to[a].append((b, cost))
from_to[b].append((a, cost))
hq.heappush(priority, (0, 0)) # (cost, start)
while False in visited:
cost, start = hq.heappop(priority)
if visited[start]:
continue
visited[start] = True
answer += cost
for end, cost in from_to[start]:
if visited[end]:
continue
else:
hq.heappush(priority, (cost, end))
print(f"start: {start}, end: {end}")
return answer
프림을 이용한 풀이다. 일단 한 곳에서 시작해서 트리를 유지하면서 트리를 확장해나간다.
parent = {}
rank = {}
# 정점을 독립적인 집합으로 만든다.
def make_set(v):
parent[v] = v
rank[v] = 0
# 해당 정점의 최상위 정점을 찾는다.
def find(v):
if parent[v] != v:
parent[v] = find(parent[v])
return parent[v]
# 두 정점을 연결한다.
def union(v, u):
root1 = find(v)
root2 = find(u)
if root1 != root2:
# 짧은 트리의 루트가 긴 트리의 루트를 가리키게 만드는 것이 좋다.
if rank[root1] > rank[root2]:
parent[root2] = root1
else:
parent[root1] = root2
if rank[root1] == rank[root2]:
rank[root2] += 1
def kruskal(n, costs):
for v in range(n):
make_set(v)
mst = []
edges = [(costs[i][2], costs[i][0], costs[i][1]) for i in range(len(costs))]
edges.sort()
for edge in edges:
weight, v, u = edge
if find(v) != find(u):
union(v, u)
mst.append(edge)
return mst
def solution(n, costs):
answer = 0
temp = kruskal(n, costs)
for i in range(len(temp)):
answer += temp[i][0]
return answer
https://goldfishhead.tistory.com/51
무조건 비용이 제일 적은 것부터 선택하고, 다만 그렇게 트리를 만들 때 사이클이 생기면 안되기 때문에 사이클 생성 여부를 판단하기 위해 유니온 파인드 알고리즘이 사용된다.