gmlwid9405님 블로그 참고
그레이트쪼님 블로그 참고
두 노드가 같은 그래프에 속하는 지 확인하기 위한 알고리즘. 기본적으로 x 가 어떤 집합에 포함되는 지 확인하기 위한
Find
와 집합 x와 y를 합치는Union
으로 구성된다.
서로 중복되지 않는 부분 집합들 로 이루어진 원소들에 대한 정보를 저장하고 조작하는 자료구조
즉, 공통요소가 없는 상호 배타적인 부분집합들로 나눠진 원소들에 대한 자료구조이다. (다른 말로, 서로소 집합)
Disjoint Set 을 표현할 때 사용하는 알고리즘
배열로 구현한다면
Array[i] : i번 원소가 속하는 집합의 번호(즉, 루트 노드의 번호)
Array[i] = i
와 같이 각자 다른 집합 번호로 초기화트리로 구현한다면
같은 집합 = 하나의 트리, 즉 집합 번호 = 루트 노드
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n,m = list(map(int,input().split()))
parent = [i for i in range(n+1)]
def union(a,b):
a = find(a)
b = find(b)
if a < b:
parent[b] = a
else:
parent[a] = b
def find(a):
if parent[a] == a:
return a
else:
return find(parent[a])
for _ in range(m):
command, a, b = list(map(int, input().split()))
# a의 집합과 b의 집합을 합친다
if command == 0:
union(a,b)
# a와 b가 같은 집합인지
else:
print('YES') if find(a) == find(b) else print('NO')
기본.
각 트리에 대해 높이(rank)를 기억해두고,
union 시 두 트리의 rank 가 다르면 높이가 작은 트리를 높이가 트리의 루트에 붙인다. -> 길이는 높은 쪽의 depth
높이가 같은 트리를 합칠 때에는 한 쪽 트리 높이를 1 증가시키고 다른 트리를 해당 트리에 붙인다. -> rank 에 대힌 집합을 따로 만든다.
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n,m = list(map(int,input().split()))
parent = [i for i in range(n+1)]
rank = [0 for _ in range(n+1)]
def union(a,b):
a = find(a)
b = find(b)
if a == b:
return
# 높이가 낮은 트리를 높은 트리 밑에 넣는다
if rank[a] < rank[b]:
parent[a] = b
else:
parent[b] = a
if rank[a] == rank[b]:
rank[a]+=1
def find(a):
if parent[a] == a:
return a
else:
return find(parent[a])
for _ in range(m):
command, a, b = list(map(int, input().split()))
# a의 집합과 b의 집합을 합친다
if command == 0:
union(a,b)
# a와 b가 같은 집합인지
else:
print('YES') if find(a) == find(b) else print('NO')
find를 실행한 노드에서 거쳐간 노드를 root 에 direct 로 연결
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n,m = list(map(int,input().split()))
parent = [i for i in range(n+1)]
def union(a,b):
a = find(a)
b = find(b)
if a < b:
parent[b] = a
else:
parent[a] = b
def find(a):
if parent[a] == a:
return a
else:
parent[a] = find(parent[a])
return parent[a]
for _ in range(m):
command, a, b = list(map(int, input().split()))
# a의 집합과 b의 집합을 합친다
if command == 0:
union(a,b)
# a와 b가 같은 집합인지
else:
print('YES') if find(a) == find(b) else print('NO')
관련문제
1. BOJ 1717 집합의 표현: https://www.acmicpc.net/problem/1717