알고리즘 유형 : 유니온 파인드
풀이 참고 없이 스스로 풀었나요? : 학습
https://www.acmicpc.net/problem/1717
import sys
input = sys.stdin.readline
def find(x):
if parent[x] < 0:
return x
parent[x] = find(parent[x])
return parent[x]
def union(x, y):
x = find(x)
y = find(y)
if x == y:
return
if parent[x] < parent[y]:
parent[y] = x
elif parent[x] > parent[y]:
parent[x] = y
else:
parent[y] -= 1
parent[x] = y
return
n, m = map(int, input().split())
parent = [-1]*(n+1)
for _ in range(m):
cmd, a, b = map(int, input().split())
if cmd:
if find(a) == find(b):
print("YES")
else:
print("NO")
else:
union(a, b)
풀이 요약
weighted union find를 학습하기 좋은 문제이다.
우선 초기에 그래프는, 각 노드에 순서대로 자연수가 노드 번호로 붙는다고 가정할 때, 그 노드 번호의 인덱스에 루트 or 부모 노드 값이 들어있는 리스트로 표현한다.
단, 자기 자신이 루트노드인 경우는 값이 음수이고, 절댓값은 자신이 속해 있는 트리의 높이를 뜻한다. 그 외의 경우에는 양수가 들어있고 이는 부모노드를 의미한다.
예를 들어 그래프가 1부터 7까지의 노드가 있고 모두 단일 노드 트리일 때, 리스트는
parent = [-1, -1, -1, -1, -1, -1, -1, -1]이 된다.
값이 음수이면 자기 자신이 루트 노드라는 뜻이고 절댓값은 트리의 높이를 의미한다.
find는 입력으로 들어온 특정 노드
가 속한 트리의 루트 노드를 반환하는 함수이다.
parent값이 음수라면 본인이 루트 노드이므로 본인을 반환해주고, 양수라면 자신의 부모 노드에 대해 재귀적으로 루트 노드를 찾는다.
이 때, 최적화를 목적으로
return find(parent[x])가 아닌,
parent[x] = find(parent[x])
return parent[x]
로 작성해준다. (parent 갱신)
union은 두 노드가 속한 집합을 합치는 함수이다.
입력으로 들어온 두 노드 x, y의 루트 노드를 찾고, 이 둘이 같으면 이미 같은 집합이므로 return.
다르면 서로 다른 집합이므로 합쳐준다.
x = find(x)
y = find(y)
로 루트 노드를 찾음과 동시에 x, y를 루트 노드로 설정해준다.
이 후, x, y의 parent값의 절댓값은 본인이 속한 트리의 높이를 나타내는데, 트리의 높이가 작은 것을 큰 쪽으로 합치는게 트리의 높이를 최소화하는 방법이다.
만약 b 트리를 a 트리에 합칠 때, b 트리 입장에서는 본인의 트리에서 위에 a 라는 루트 노드가 하나 더 생긴 것이므로 트리 높이가 1 늘어난다. 만약 이 때 b 트리의 높이가 a 트리 높이보다 크다면 합치고 난 후의 트리 높이는 b+1이 된다.
그러나 a 트리를 b 트리에 합치게 되면 a+1은 b 트리 높이 값보다 같거나 작으므로, 트리 높이는 유지가 된다. 이런 이유로 위의 방법으로 트리를 합친다.
만약 트리 높이가 서로 같다면, 아무 트리에나 합친 다음, 합침 기준 트리의 parent 값에 1 더 해준다.
배운 점, 어려웠던 점