이렇게 모든 노드(원소)를 연결해야하고 연결하기 위한 최소 비용을 구하는 문제는 높은 확률로 MST (최소 스파닝 트리, 최소 신장 트리)로 해결하라는 문제이다.
그래프 상의 모든 노드가 사이클 없이 연결된 부분 그래프를 최소 연결 부분 그래프, Spanning Tree라고 한다. 이러한 신장 트리는 아래와 같은 특징을 가지고 있다.
이런 신장트리 중에서 모든 간선의 비용 합이 최소가 되는 신장 트리를 최소 신장 트리(Minimum Spanning Tree)라고 하고 최소 신장 트리는 그래프에 단 하나만 존재한다.
그래프가 주어졌을때 최소 신장 트리는 Kruskal 알고리즘 또는 Prim 알고리즘으로 찾을 수 있다.
탐욕적인 방법을 이용하여 모든 정점을 최소 비용으로 연결하는 방법
크루스칼 알고리즘은 쉽게 말해 지금 어떤 신장 트리가 만들어졌는지에는 관심이 없고 가지고 있는 간선들 중에서 최소 비용인 간선만을 선택해 나가는 방법이다. 이 때 사이클을 만들지 않는 간선들을 선택해야 한다.
시작 정점에서부터 출발하여 신장트리 집합을 단계적으로 확장해나가는 방법
이 방법은 Kruskal 알고리즘과는 다르게 현재 만들어져 있는 스파닝 트리에 인접한 정점들 중 최소 비용 간선을 선택해 나가는 방법이다.
각 Kruskal Algorithm, Prim Algorithm의 과정 예는 위키피디아의 문서를 통해 확인할 수 있다.
최소 신장 트리를 구현하기 위해 크루스칼 알고리즘을 사용할 때, 사이클을 만들지 않고 모든 노드를 방문하기 위해 Union-Find를 사용한다. Union-Find 알고리즘의 경우 그 노드의 부모 노드 번호를 기억하며 해당 부모 노드로 쭉 따라 올라가면서 트리의 루트 노드를 찾게 된다.
같은 트리에 속한, 즉 같은 집합에 속한 노드의 경우 같은 루트 노드를 가지게 된다. 만약 루트 노드를 찾았는데 다른 루트 번호를 가진다면 해당 집합에 속하지 않는 노드라는 뜻이다.
이 과정을 구현하기 위해선 Union과 Find가 필요하다.
즉, 쉽게 말해 Find로 루트 노드를 찾고 Union으로 두 트리를 병합해 나간다. 이런 식으로 사이클을 만들지 않으면서 결국 모든 노드를 트리 형태로 병합하면 그 트리가 곳 신장 트리가 된다.
트리 형태 자체가 사이클이 없는 그래프를 의미한다. 따라서 트리와 트리를 병합한다는 것은 곳 사이클을 만들지 않고 그래프를 연결한다는 뜻이다.
def find(n) :
if parent[n] != n :
parent[n] = find(parent[n])
return parent[n]
def union(a, b) :
a = find(a)
b = find(b)
if a != b :
parent[a] = b
이제 Kruskal Algorithm을 구현할 수 있다. 간선을 비용 순으로 정렬해준 후에 해당 간선을 방문하면서 해당 간선의 두 노드가 포함되어있는 트리의 루트 노드를 확인(Find)한다. 만약 두 노드의 루트 노드가 다르다면 해당 간선을 연결해주면서 두 트리를 병합(Union)해준다. 이런 식으로 결국 모든 노드가 한 트리에 포함하게 되면 해당 트리가 최소 신장 트리가 된다.
import sys
input = sys.stdin.readline
def find(n) :
if parent[n] != n :
parent[n] = find(parent[n])
return parent[n]
def union(a, b) :
a = find(a)
b = find(b)
if a > b : parent[b] = a
else : parent[a] = b
n = int(input())
stars = [tuple([i] + list(map(int,input().rstrip().split()))) for i in range(n)]
edges = []
for i in [1, 2, 3] :
sort_stars = sorted(stars, key = lambda x : x[i])
edges += [(abs(star1[i] - star2[i]), star1[0], star2[0]) for star1 ,star2 in zip(sort_stars[:-1], sort_stars[1:])]
edges.sort()
ans = 0
parent = list(range(n+1))
for dist, star1, star2 in edges :
if find(star1) != find(star2) :
union(star1, star2)
ans += dist
print(ans)