초기에 n+1개의 집합 {0}, {1}, {2}, ... ,{n}이 있다. 여기에 합집합 연산과, 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산을 수행하려고 한다.
집합을. 표현하는 프로그램을 작성하시오.
첫째 줄에 n, m이 주어진다. m은 입력으로 주어지는 연산의 개수이다. 다음 m개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는 a가 포함되어 있는 집합과, b가 포함되어 있는 집합을 합친다는 의미이다. 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산은 1 a b의 형태로 입력이 주어진다. 이는 a와 b가 같은 집합에 포함되어 있는지를 확인하는 연산이다.
1로 시작하는 입력에 대해서 a와 b가 같은 집합에 포함되어 있으면 "YES" 또는 "yes"를, 그렇지 않다면 "NO" 또는 "no"를 한 줄에 하나씩 출력한다.
7 8
0 1 3
1 1 7
0 7 6
1 7 1
0 3 7
0 4 2
0 1 1
1 1 1
NO
NO
YES
해당 문제는 어떤 자료구조를 쓰느냐에 따라서 프로그램이 돌아가는 시간이 달라질 수 있는 시간 복잡도에 관련된 문제라고 생각했다.
일단 문제에 적힌 대로 집합으로 표현을 하고자 했는데, 그렇게 하는거보다 딕셔너리로 표현하는 게 낫다고 생각을 했다. Key, Value 모두 초기에는 0 : 0, 1 : 1 ..., 7 : 7(n이 7이라고 가정할 때)라고 선언을 한다. 알고리즘이 흘러가는 방식을 정리하면 다음과 같다.
n = 7(가정)
dict
0 : 0, 1 : 1, 2 : 2, ... 7 : 7
0 2 3 (가정)
key 2의 value = 2
key 3의 value = 3
둘 중의 최소 값은 2
value가 2 또는 3인 key들을 모두 2로 update
0 : 0, 1 : 1, 2 : 2, 3 : 2, ... , 7 : 7
... (반복)
코드로 적으면 다음과 같다.
a, b = map(int, input().split())
s = {}
for i in range(a+1) :
s[i] = i # dict 선언
for i in range(b) :
c, d, e = map(int, input().split())
if c == 0 :
m = min(s[d],s[e]) # value 값 중 작은 것들 고름
for j in range(a+1) :
if s[j] == s[d] or s[j] == s[e] : # 합집합으로 만드는 과정
s[j] = m # 합집합으로 만들면서 value 값들 중 작은 것들로 update
else :
if s[d] == s[e] : # value가 같으면 같은 집합 안이라는 말
print('YES')
else :
print('NO')
이렇게 하니 아쉽게도 시간 초과가 떴다.
아마 다음 부분이 오래 걸려서 그런 것 같다.
for j in range(a+1) :
if s[j] == s[d] or s[j] == s[e] : # 합집합으로 만드는 과정
s[j] = m # 합집합으로 만들면서 value 값들 중 작은 것들로 update
딕셔너리를 모두 찾아가며 일일히 업데이트해서 그런 거 같다.
그러면 head가 되는, 그러니까 가장 대장(?)이 될 수 있는 value의 key 값을 찾아서 비교하고 합집합인지 여부를 알면 되는데, 이 때 학교 수업 때 배웠던 Union Find
가 생각났다.
사실 내가 구현했던(시간 초과한) 방법도 Union Find 방법이지만, Union Find의 문제 자체가 Union할 때의 시간이 오래 걸리는 것이었다.
QU와 WQU의 차이는 트리의 깊이(depth) 차이이다.
Quick-Union은 앞서 말했던 대장(?)(부모 노드)을 만들어서 합집합의 id(component id)를 체크해서 합집합인지 여부를 찾을 수 있는 방식이다.
이거는 tree 형식으로 갖다 붙이는데, 경우에 따라 한 노드 방향으로 계속 tree가 밑에 붙여지면 깊이가 너무 깊어져 찾는데에 시간이 아주 오래 걸릴 수도 있다.
이 점을 보완한 것이 WQU이다.
그래서 Union을 할 때, 합집합을 만들때도 해당 component들의, 즉 합집합의 사이즈를 비교해서 큰 합집합의 밑에 작은 합집합이 올 수 있도록 만든다.
list를 활용해서 코드를 짰다.
생각해보니 index가 key 값이 될 수 있고, value가 list의 값이 될 수 있기 때문이다.
그래서 코드는 다음과 같다.
a, b = map(int, input().split())
ids = []
size = [] # 합집합의 size 알기
for idx in range(a+1) :
ids.append(idx)
size.append(1)
def root(i) :
while i != ids[i] :
i = ids[i]
return i # 합집합의 id 찾기 -> 부모 노드 찾기
def connected(p, q) :
return root(p) == root(q) # root 함수의 return 값이 같으면 같은 합집합
def union(p,q) :
id1, id2 = root(p), root(q)
if id1 == id2 :
return
if size[id1] <= size[id2] :
ids[id1] = id2
size[id2] += size[id1]
else :
ids[id2] = id1
size[id1] += size[id2]
for i in range(b) :
c, d, e = map(int, input().split())
if c == 0 :
union(d,e)
else :
if connected(d,e) == True:
print('YES')
else :
print('NO')
# print(ids) -> 디버깅용