[ BOJ / Python ] 1717번 집합의 표현

황승환·2022년 3월 6일
0

Python

목록 보기
227/498


이번 문제는 union-find 알고리즘을 통해 해결하였다. 처음에는 문제에서 주어진 그대로 집합 자료형을 사용하여 합집합을 만들어 관리하도록 하였지만 시간 초과가 발생하였다.

n, m=map(int, input().split())
result=[{i} for i in range(n+1)]
for _ in range(m):
    cul, a, b=map(int, input().split())
    if cul==0:
        if a==b:
            continue
        a_set={}
        b_set={}
        rmv=[]
        for i in range(len(result)):
            if {a} & result[i]:
                a_set=result[i]
                rmv.append(result[i])
            if {b} & result[i]:
                b_set=result[i]
                rmv.append(result[i])
        result.append(a_set | b_set)
        result.remove(rmv[0])
        result.remove(rmv[1])
    if cul==1:
        chk=False
        for i in range(len(result)):
            if {a, b} & result[i]=={a, b}:
                chk=True
                break
            else:
                continue
        if chk:
            print('YES')
        else:
            print('NO')

0부터 n까지의 집합을 만들어 result에 담아두고, 0이 입력되면 a, b가 포함된 집합을 찾은 후, 이 둘의 합집합을 result에 저장하고, a, b가 담겨있던 집합을 지우는 형태로 작성하였고, 1이 입력되면 {a, b}와 result[i]의 교집합이 {a, b}일 경우, YES를 출력하도록, 끝까지 이 조건에 만족하지 못하면 NO를 출력하도록 하였다. 답은 제대로 반환되었지만, 시간 초과가 발생하였기에 다른 알고리즘을 적용해야겠다고 생각했다.

그래서 다른 알고리즘을 찾아보던 중 union-find 알고리즘을 알게 되었다. find 함수에서는 해당 수의 루트를 찾도록 하고, union 함수에서는 find에서 찾은 a, b의 루트를 비교하여, 같을 경우 그대로 두고, 다를 경우, 두 집합을 합치도록 한다. 그래서 0이 입력되면 union함수를 호출하고, 1이 입력되면 find(a)와 find(b)의 값을 비교하여 같을 경우에 YES, 아닐 경우에 NO를 출력하도록 하였다.

  • n, m을 입력받는다.
  • parent 리스트에 1부터 n까지 담는다.
  • find 함수를 cur을 인자로 갖도록 선언한다.
    -> 만약 cur이 parent[cur]과 같을 경우,
    --> cur을 반환한다.
    -> parent[cur]find(parent[cur])의 재귀 호출 반환값으로 저장한다.
    -> parent[cur]을 반환한다.
  • union 함수를 a, b를 인자로 갖도록 선언한다.
    -> a를 find(a)의 반환값으로 선언한다.
    -> b를 find(b)의 반환값으로 선언한다.
    -> 만약 a와 b가 같을 경우,
    --> 함수를 종료한다.
    -> 만약 a가 b보다 작을 경우,
    --> parent[b]를 a로 갱신한다.
    -> 그 외의 경우,
    --> parent[a]를 b로 갱신한다.
  • m번 반복하는 for문을 돌린다.
    -> cul, a, b를 입력받는다.
    -> 만약 cul이 0일 경우, union(a, b)를 호출한다.
    -> 만약 cul이 1일 경우,
    --> 만약 find(a)find(b)가 같을 경우, YES를 출력한다.
    --> 그 외의 경우, NO를 출력한다.
import sys
input=sys.stdin.readline
sys.setrecursionlimit(10**9)
n, m=map(int, input().split())
parent=[i for i in range(n+1)]
def find(cur):
    if cur==parent[cur]:
        return cur
    parent[cur]=find(parent[cur])
    return parent[cur]
def union(a, b):
    a=find(a)
    b=find(b)
    if a==b:
        return
    if a<b:
        parent[b]=a
    else:
        parent[a]=b
for _ in range(m):
    cul, a, b=map(int, input().split())
    if cul==0:
        union(a, b)
    elif cul==1:
        if find(a)==find(b):
            print('YES')
        else:
            print('NO')

profile
꾸준함을 꿈꾸는 SW 전공 학부생의 개발 일기

0개의 댓글