이번 문제는 union-find 알고리즘을 통해 해결하였다. 처음에는 문제에서 주어진 그대로 집합 자료형을 사용하여 합집합을 만들어 관리하도록 하였지만 시간 초과가 발생하였다.
n, m=map(int, input().split())
result=[{i} for i in range(n+1)]
for _ in range(m):
cul, a, b=map(int, input().split())
if cul==0:
if a==b:
continue
a_set={}
b_set={}
rmv=[]
for i in range(len(result)):
if {a} & result[i]:
a_set=result[i]
rmv.append(result[i])
if {b} & result[i]:
b_set=result[i]
rmv.append(result[i])
result.append(a_set | b_set)
result.remove(rmv[0])
result.remove(rmv[1])
if cul==1:
chk=False
for i in range(len(result)):
if {a, b} & result[i]=={a, b}:
chk=True
break
else:
continue
if chk:
print('YES')
else:
print('NO')
0부터 n까지의 집합을 만들어 result에 담아두고, 0이 입력되면 a, b가 포함된 집합을 찾은 후, 이 둘의 합집합을 result에 저장하고, a, b가 담겨있던 집합을 지우는 형태로 작성하였고, 1이 입력되면 {a, b}와 result[i]의 교집합이 {a, b}일 경우, YES를 출력하도록, 끝까지 이 조건에 만족하지 못하면 NO를 출력하도록 하였다. 답은 제대로 반환되었지만, 시간 초과가 발생하였기에 다른 알고리즘을 적용해야겠다고 생각했다.
그래서 다른 알고리즘을 찾아보던 중 union-find 알고리즘을 알게 되었다. find 함수에서는 해당 수의 루트를 찾도록 하고, union 함수에서는 find에서 찾은 a, b의 루트를 비교하여, 같을 경우 그대로 두고, 다를 경우, 두 집합을 합치도록 한다. 그래서 0이 입력되면 union함수를 호출하고, 1이 입력되면 find(a)와 find(b)의 값을 비교하여 같을 경우에 YES, 아닐 경우에 NO를 출력하도록 하였다.
parent[cur]
과 같을 경우,parent[cur]
을 find(parent[cur])
의 재귀 호출 반환값으로 저장한다.parent[cur]
을 반환한다.find(a)
의 반환값으로 선언한다.find(b)
의 반환값으로 선언한다.parent[b]
를 a로 갱신한다.parent[a]
를 b로 갱신한다.union(a, b)
를 호출한다.find(a)
와 find(b)
가 같을 경우, YES를 출력한다.import sys
input=sys.stdin.readline
sys.setrecursionlimit(10**9)
n, m=map(int, input().split())
parent=[i for i in range(n+1)]
def find(cur):
if cur==parent[cur]:
return cur
parent[cur]=find(parent[cur])
return parent[cur]
def union(a, b):
a=find(a)
b=find(b)
if a==b:
return
if a<b:
parent[b]=a
else:
parent[a]=b
for _ in range(m):
cul, a, b=map(int, input().split())
if cul==0:
union(a, b)
elif cul==1:
if find(a)==find(b):
print('YES')
else:
print('NO')