파이썬 알고리즘 289번 | [백준 1717번] 집합의 표현 - 그래프(union & find : 합집합 찾기, 서로소 집합)

Yunny.Log ·2023년 1월 6일
0

Algorithm

목록 보기
294/318
post-thumbnail

289. 집합의 표현

1) 어떤 전략(알고리즘)으로 해결?

자료 구조
분리 집합 - union & find , Disjoint set (서로소 집합)

2) 코딩 설명

  • 유니온 파인드 구현 & 주석 설명

<내 풀이>


from collections import defaultdict
import sys
sys.setrecursionlimit(10**5)
n,m = map(int, sys.stdin.readline().rstrip().split())
parent = defaultdict(int) 

for i in range(n+1) : # 0~n : n+1개의 집합
    parent[i] = i # n원소는 n이라는 집합에 속함
    
def find(x) : # x의 조상(x가 속한 집합)을 찾는다.
    if parent[x]!=x : # 만약 x의 조상이 최종 조상이 아니라면
        parent[x] = find(parent[x]) # x의 조상의 조상을 x의 조상으로 삼는다
    return parent[x] # x의 조상이 최종 조상이 될 때 비로소 최종 조상을 반환한다. 

def union(a,b) : # a,b를 합집합으로 만들어준다. (a,b 조상을 통일해준다) 
    pa  = find(a) 
    pb = find(b)
    # 둘의 조상 중 둘 중 더 작은 조상으로 통일해준다.
    if pa < pb :
        parent[pb] = pa
    else : 
        parent[pa] = pb

for i in range(m) :

    o, a, b = map(int, sys.stdin.readline().rstrip().split())
    if o==0: # UNION : 합집합 
        union(a,b)

    else : # FIND : 두 원소가 같은 집합에 포함되어 있는지를 확인 
        if find(a) == find(b) :
            print("YES")
        else : 
            print("NO")

< 내 틀렸던 풀이, 문제점>

(1) 21프로 - 메모리 초과 !

from collections import defaultdict
import sys
n,m = map(int, sys.stdin.readline().rstrip().split())
dic = defaultdict(int)
val_dic = defaultdict(list) 
for i in range(n+1) : # 0~n : n+1개의 집합
    dic[i] = i # n원소는 n이라는 집합에 속함
    val_dic[i].append(i) # n이라는 집합에 속한 n원소

for i in range(m) :
    o, a, b = map(int, sys.stdin.readline().rstrip().split())
    if o==0: # 합집합 
        if dic[a]!=dic[b] : # 둘이 같은 집합 아닐 때만
            tmp = dic[b]
            val_dic[dic[a]]+=val_dic[dic[b]] #a 속한 집합에 원소로 b속한 집합 원소들 등록
            for j in val_dic[dic[b]] : 
                dic[j] = dic[a] # b 집합 애들 a가 속한 집합에 등록
            val_dic[tmp].clear # 비워주기 
    else : # 두 원소가 같은 집합에 포함되어 있는지를 확인 
        if dic[a]==dic[b] :
            print("YES")
        else : 
            print("NO")
  • 합집합인 애들을 모조리 쌩으로 더하다보니깐 자연스레 메모리초과가 나는 것이라고 생각합니다.
val_dic[dic[a]]+=val_dic[dic[b]]

(2) 5프로 - 시간 초과 !

from collections import defaultdict
import sys
n,m = map(int, sys.stdin.readline().rstrip().split())
dic = defaultdict(int) 
for i in range(n+1) : # 0~n : n+1개의 집합
    dic[i] = i # n원소는 n이라는 집합에 속함
    
for i in range(m) :
    o, a, b = map(int, sys.stdin.readline().rstrip().split())
    if o==0: # 합집합 
        if dic[a]!=dic[b] : # 둘이 같은 집합 아닐 때만
            tmp = dic[b]
            for j in range(len(dic)) :
                if dic[j]==tmp :
                    dic[j] = dic[a]
    else : # 두 원소가 같은 집합에 포함되어 있는지를 확인 
        if dic[a]==dic[b] :
            print("YES")
        else : 
            print("NO")
  • 당연하겠지만 이중 for문을 도는 것은 시간초과가 나게 되지요.

(3) Union-find 알고리즘에 대해 학습했습니다.

  • 그러나 틀렸습니다가 발생합니다. 왜일까요? 짐작이 잘 가지 않습니다.
from collections import defaultdict
import sys
n,m = map(int, sys.stdin.readline().rstrip().split())
parent = defaultdict(int) 
for i in range(n+1) : # 0~n : n+1개의 집합
    parent[i] = i # n원소는 n이라는 집합에 속함
    
def find(x) : # x의 조상(x가 속한 집합)을 찾는다.
    if parent[x]!=x : # 만약 x의 조상이 최종 조상이 아니라면
        parent[x] = find(parent[x]) # x의 조상의 조상을 x의 조상으로 삼는다
    return parent[x] # x의 조상이 최종 조상이 될 때 비로소 최종 조상을 반환한다. 

def union(a,b) : 
	# a,b를 합집합으로 만들어준다. (a,b 조상을 통일해준다) 
    pa  = find(a) 
    pb = find(b)
    # 둘의 조상 중 둘 중 더 작은 조상으로 통일해준다.
    if pa < pb :
        parent[b] = pa
    else : 
        parent[a] = pb

for i in range(m) :
    o, a, b = map(int, sys.stdin.readline().rstrip().split())
    if o==0: # UNION : 합집합 
        union(a,b)
    else : # FIND : 두 원소가 같은 집합에 포함되어 있는지를 확인 
        if find(a) == find(b) :
            print("YES")
        else : 
            print("NO")

union 함수를 잘못 구현했던 것입니다..

잘못 구현한 union 함수


def union(a,b) : # a,b를 합집합으로 만들어준다. (a,b 조상을 통일해준다) 
    pa  = find(a) 
    pb = find(b)
    # 둘의 조상 중 둘 중 더 작은 조상으로 통일해준다.
    if pa < pb :
        parent[b] = pa # 틀린 부분입니다.
    else : 
        parent[a] = pb

union에서 a 혹은 b의 집합 합치기 과정을 진행할 때

a 혹은 b의 parent를 갱신하는 것이 아닌, a 혹은 b의 parent의 parent를 갱신시켜주어야 합니다.


def union(a,b) : # a,b를 합집합으로 만들어준다. (a,b 조상을 통일해준다) 
    pa  = find(a) 
    pb = find(b)
    # 둘의 조상 중 둘 중 더 작은 조상으로 통일해준다.
    if pa < pb :
        parent[pb] = pa
    else : 
        parent[pa] = pb
  • 따라서 a 혹은 b의 parent (a,b가 속한 집합) 의 parent 를 갱신시켜주었습니다.

<반성 점>

  • union 함수 제대로 구현해야 한다는 점을 느꼈습니다.

union에서 a 혹은 b의 집합 합치기 과정을 진행할 때
a 혹은 b의 parent를 갱신하는 것이 아닌,
a 혹은 b의 parent의 parent를 갱신시켜주어야 합니다.

<배운 점>

#1

int find(int x) {

    if(x==parent[x]) return x;

    return find(parent[x]);

}
#2

int find(int x) {

    if(x==parent[x]) return x;

    return parent[x] = find(parent[x]);

}
  • #2의 find 함수는 x가 속한 트리의 루트 노드 값을 루트 노드에서 현재 노드까지 재귀적으로 반환합니다.
  • 이 때 parent[x] = find(parent[x]); 코드는 재귀적으로 반환하는 각 과정(경로)에 있는 모든 노드의 부모 노드를 루트 노드로 바꾸어 주는 것을 의미합니다.
  • #1의 find 함수는 매번 루트까지 재귀호출을 해서 루트 노드 값을 얻어오는 반면, #2의 find 함수는 한 번 호출하면, 그 경로에 있는 모든 노드의 부모 노드가 루트가 됩니다.
  • 따라서 그 다음부터 경로상에 있던 노드에 한해서 find 함수를 호출한다면, 루트노드를 찾기 위한 재귀 호출이 발생하지 않습니다.
  • 그 이유는 재귀 호출 과정에서 이미 부모 노드를 루트 노드로 갱신해놨기 때문입니다. 따라서 O(1) 시간에 자신이 속한 집합을 알아낼 수 있습니다. 이를 '경로 압축 최적화' 라고 합니다.

0개의 댓글