서로 중복되지 않는 부분 집합들로 나눠진 원소들에 대한 정보를 저장하고 조작하는 자료구조
즉, 다수의 노드들 중에 연결된 노드를 찾거나 합칠 때 사용하는 알고리즘
유니온 파인드 구현 방식은 1. 배열 2. 트리 방식이 있다
def init(n):
data = list(range(n))
def find(index):
return data[index]
def union(x, y):
x = find(x)
y = find(y)
if x == y:
return
for i in range(n):
if find(i) == y:
data[i] = x
하지만 배열 방식은 시간 복잡도가 크다는 단점이 있다.
대표적으로 unio-by-size 형태를 보자
data = [ -1 for _ in range(n) ]
def find(idx):
value = data[idx]
if value < 0:
return idx
return find(value) #최상위의 루트 노드 반환
def union(x, y):
x = find(x)
y = find(y)
if x== y:
return
if data[x] < data[y]:
data[x] += data[y]
data[y] = x
else:
data[y] += data[x]
data[x] = y
즉, 루트 노드의 값은 계속해서 음수가 되고 나머지는 루트 노드의 값을 가리키게 된다.
그럼 예제 문제를 풀어보자
문제
초기에 {0}, {1}, {2}, ... {n} 이 각각 n+1개의 집합을 이루고 있다. 여기에 합집합 연산과, 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산을 수행하려고 한다.
집합을 표현하는 프로그램을 작성하시오.
입력
첫째 줄에 n(1 ≤ n ≤ 1,000,000), m(1 ≤ m ≤ 100,000)이 주어진다. m은 입력으로 주어지는 연산의 개수이다. 다음 m개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는 a가 포함되어 있는 집합과, b가 포함되어 있는 집합을 합친다는 의미이다. 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산은 1 a b의 형태로 입력이 주어진다. 이는 a와 b가 같은 집합에 포함되어 있는지를 확인하는 연산이다. a와 b는 n 이하의 자연수 또는 0이며 같을 수도 있다.
출력
1로 시작하는 입력에 대해서 한 줄에 하나씩 YES/NO로 결과를 출력한다. (yes/no 를 출력해도 된다)
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n, m = map(int, input().split())
arr = [ i for i in range(n+1)]
def find(idx):
global arr
value = arr[idx]
if value == idx:
return idx
#부모 노드 갱신
arr[idx] = find(value)
return arr[idx]
def union(a, b):
global arr
#루트 노드를 찾는다
x = find(a)
y = find(b)
if x == y:
return
if x < y:
arr[y] = x
else:
arr[x] = y
for _ in range(m):
a, b, c = map(int, input().split())
if a == 0:
union(b, c)
if a == 1:
if find(b) == find(c):
print("YES")
else:
print("NO")
#print(arr)