[Python] Union Find? Disjoint Set은 무엇인가 ?

Yg999999999·2024년 3월 12일
1

알고리즘

목록 보기
3/4
post-thumbnail

🤔Union Find?

유니온 파인드, 혹은 분리 집합이라 불리는 알고리즘은 두 원소가 같은 집합에 속하는지 판별해주는 알고리즘이다. 두 원소를 같은 집합으로 분류하는 merge(합집합) 연산과 원소가 어느 집합에 속해있는지 알려주는 find 연산이 사용된다. 간단한 예제를 보면서 알고리즘을 알아보자.

✅예제

1717번: 집합의 표현

문제 조건

1 ≤ n ≤ 100,0000

1≤ m ≤ 100,000

0 ≤ a≤ b ≤ n

a, b는 정수

알고리즘 개요

n+1 개의 원소가 있고 m개의 연산이 실행된다. 연산은 두 가지로 두 원소가 포함된 집합을 합집합 하는 연산과 두 원소가 같은 집합에 속해있는지 확인하는 연산이 있다. 집합을 표현하기 위해서 n+1개의 배열을 선언한다. 배열의 값은 해당 원소(인덱스)가 무슨 집합에 속해있는지 가리킨다.

집합을 가리키기 위해서는 집합의 대표 번호를 정한다. 이때 해당 집합의 가장 작은 숫자를 집합의 대표 번호로 사용한다. 단, 자신이 집합의 대표 번호라면 -1로 선언한다.

예를 들면, 집합 Set1 = { 1, 2, 3, 5} , Set2 = { 0, 4} 가 존재한다고 가정해보자. 아래의 사진과 같이 arr 배열로 집합을 표기할 수 있을 것이다.

알고리즘 구현

해당 알고리즘의 구현을 위해 두 가지 연산이 필요하다.

  1. find 연산 > 원소가 어느 집합에 속해있는지 알려주는 연산.
  2. merge 연산 > 두 원소를 같은 집합으로 분류하는 연산.

find 연산

def find(node):

	if arr[node] == -1:
			return node
	
	arr[node] = find(arr[node])
	return arr[node]

자신이 속한 집합을 찾는 함수이다. 만약 집합이 S1 = {1,2,3,5} 라면 find(2)는 1을 반환해야 할 것이다. find는 2번 원소가 해당한 집합의 번호(1,2,3,5 중 가장 작은 수)를 알려주어야 하기 때문이다.

find에서 가장 핵심적인 코드는 arr[node] = find(arr[node]) 이다. 재귀함수라 의미를 파악하기 어렵기 때문에 처음 유니온 파인드를 공부하면 이 부분에서 의문을 가질 수 있을 것이다. 하지만 문제를 몇 번 풀다 보면 금방 이해할 것이다. 당장 이해가 안 된다면 꼭 외우도록 하자 !

merge 연산

def merge(a,b):

		a = find(a)
		b = find(b)
		
		if a==b: # 이미 두 원소가 같은 집합에 속해있기 때문에 합집합 연산을 수행하지 않는다.
			return
			
		big_node = max(a,b)
		small_node = min(a,b)
		
		arr[big_node] = small_node

두 원소의 집합 번호를 비교하고 큰 노드가 작은 노드를 가리키도록 설계하면 된다.

전체 코드

import sys
sys.setrecursionlimit(10**6)
input=sys.stdin.readline

def find(node):
    if arr[node]==-1:
        return node
    else:
        arr[node]=find(arr[node])
        return arr[node]

def merge(a,b):
    a=find(a)
    b=find(b)

    if a==b: #이미 a,b 같은 집합에 소속되어있다
        return

    # 노드를 잇는 과정
    big_node=max(a,b)
    small_node=min(a,b)
    arr[big_node]=small_node

def is_union(a,b):
    a=find(a)
    b=find(b)

    if a==b:
        print("YES")
    else:
        print("NO")

n,m = map(int,input().split())
arr=[-1]*(n+1)

for i in range(m):
    op,a,b=map(int,input().split())
    if op == 0:
        merge(a,b)
    else:
        is_union(a,b)
profile
BackEnd developer

0개의 댓글