union find는 서로소 집합 자료구조이다.
루트노드가 같은 노드끼리 그룹으로 묶어주는 용도이다.
1 ~ 8번 노드가 있다고 할 때 간선정보를 다음과 같이 설정하겠다.
(1 2) (2 3) (7 8) (4 6) (5 6)
import sys
input = sys.stdin.readline
def find_parent(p, x):
# 경로 압축
# 한번 실행되고나면 타고타고 들어가서 찾는게 아닌
# 바로 루트를 찾고 반환 받게 된다
if p[x] != x:
p[x] = find_parent(p, p[x])
return p[x]
# 두 노드를 같은 집합으로 처리하기
def union_parent(p, a, b):
# 각자 자신의 노드의 루트 노드를 찾는다
a = find_parent(p, a)
b = find_parent(p, b)
# 더 큰 루트노드를 가진쪽이 작은쪽을 부모로 삼는다
if a < b:
p[b] = a
else:
p[a] = b
p = [i for i in range(9)]
edge = [(1, 2), (3, 2), (7, 8), (6, 4), (5, 6)]
print(f'그룹화 전: {p[1:]}')
for a, b in edge:
union_parent(p, a, b)
for i in range(1, len(p)):
find_parent(p, i)
print(f'그룹화 후: {p[1:]}')
-결과-
그룹화 전: [1, 2, 3, 4, 5, 6, 7, 8]
그룹화 후: [1, 1, 1, 4, 4, 4, 7, 7]
-중간과정-
[1, 1, 3, 4, 5, 6, 7, 8]
[1, 1, 1, 4, 5, 6, 7, 8]
[1, 1, 1, 4, 5, 6, 7, 7]
[1, 1, 1, 4, 5, 4, 7, 7]
[1, 1, 1, 4, 4, 4, 7, 7]
- 자신의 인덱스를 값으로 가지는 배열 선언
= 자기 자신을 루트로 설정- 연결된 노드끼리는 공통된 루트를 가르키도록 union
union find를 사용하면 몇개의 그룹이 있는지, 해당 그룹의 노드는 몇개인지 알 수 있다.
사이클의 존재 유무도 알 수 있는데 방법은 다음과 같다.
간선정보를 확인하고 해당 노드끼리 묶어주기전에 노드들의 root를 먼저 확인하는데 이미 같다면 사이클이 존재하는 것이다.
예를 들어 위의 그림에서 추가로 (1, 3)의 간선 정보가 들어왔다고 한다면 이미 1, 3의 루트 노드가 같은 상태에서 묶어주게 되므로 사이클이 발생하게 되는것이다.
최소한의 비용으로 모든 노드를 연결 시키는 알고리즘이다.
그리디 알고리즘으로 분류된다.
방법은 다음과 같다.
- 간선 데이터를 비용에 따라 오름차순으로 정렬한다.
- 간선을 하나씩 확인하면서 사이클을 발생시키는지 확인한다.
-> 사이클이 발생하면 최소 신장 트리에 포함시키지 않는다.
-> 사이클이 발생하지 않으면 최소 신장 트리에 포함시킨다.- 모든 간선에 대해 2번 과정을 진행해준다.
import sys
input = sys.stdin.readline
def find_parent(p, x):
if p[x] != x:
p[x] = find_parent(p, p[x])
return p[x]
def union_parent(p, a, b):
a = find_parent(p, a)
b = find_parent(p, b)
if a < b:
p[b] = a
else:
p[a] = b
node_cnt = 7
p = [i for i in range(node_cnt + 1)]
edge = [(1, 2, 29), (1, 5, 75), (2, 3, 35), (2, 6, 34), (3, 4, 7), (4, 6, 23), (4, 7, 13), (5, 6, 53), (6, 7, 25)]
ans = 0
count = 0
edge.sort(key = lambda x: x[-1])
for node_a, node_b, cost in edge:
if find_parent(p, node_a) != find_parent(p, node_b):
union_parent(p, node_a, node_b)
ans += cost
count += 1
# 간선의 개수가 (전체 노드 개수 - 1)이 되면 모든 노드가 연결된거임
if count == node_cnt - 1:
break
print(ans)
-결과-
159