Union-find (disjoint-set) 알고리즘, [백준 1717] 집합의 표현

alllloha·2021년 9월 28일
0

Union-find 알고리즘을 이해하고 예제 문제를 풀어본다

Union-find (disjoint-set) 알고리즘

서로 중복되지 않는 부분 집합들로 나눠진 원소들에 대한 정보를 저장하고 조작하는 자료구조
즉, 다수의 노드들 중에 연결된 노드를 찾거나 합칠 때 사용하는 알고리즘

  • 예시

Union-find 구현

  • 초기화 : N개의 원소가 각각의 집합에 포함되어 있도록 초기화
  • union : 두 원소가 주어질 때, 이들이 속한 두 집합을 하나로 합친다
  • find : 어떤 원소가 주어질 때, 원소가 속한 집합 반환

유니온 파인드 구현 방식은 1. 배열 2. 트리 방식이 있다

  • 배열 ( 시간복잡도 O(n) )
    원소의 크기만큼 배열을 초기화한다.
    union : 배열의 모든 원소를 순회하면서 하나의 번호를 나머지 하나로 교체한다.
    아래는 대략 코드의 흐름
    def init(n):
    	data = list(range(n))
    
    def find(index):
    	return data[index]
   
    def union(x, y):
    	x = find(x)
        y = find(y)
        
        if x == y:
        	return
            
        for i in range(n):
        	if find(i) == y:
            	data[i] = x

하지만 배열 방식은 시간 복잡도가 크다는 단점이 있다.

  • 트리 (log n)
    트리 구조는 세 가지 형태로 나뉜다.
  1. union-by-size : 원소의 수가 적은 집합을 많은 집합의 하위트리로 추가
  2. union-by-height : 트리의 높이가 작은 집합을 큰 집합의 서브트리로 추가
  3. path comprehension : find 연산 비용을 낮춘다.

대표적으로 unio-by-size 형태를 보자

  • 주어진 원소의 개수만큼 사용하지 않는 값을 생성한다
data = [ -1 for _ in range(n) ]
  • 루트 노드의 인덱스를 찾는다
def find(idx):
	value = data[idx]
    if value < 0:
    	return idx
    return find(value) #최상위의 루트 노드 반환
  • 루트 노드의 인덱스가 다르다면 리스트의 값이 더 낮은 것(size가 큰 것)을 찾아서 큰것을 더해준다
  • 작은 걸 큰 것의 인덱스로 바꾼다
def union(x, y):
	x = find(x)
    	y = find(y)
    
      if x== y:
          return
      if data[x] < data[y]:
          data[x] += data[y]
          data[y] = x
      else:
          data[y] += data[x]
          data[x] = y

즉, 루트 노드의 값은 계속해서 음수가 되고 나머지는 루트 노드의 값을 가리키게 된다.
그럼 예제 문제를 풀어보자

[백준 1717] 집합의 표현 (python)

문제
초기에 {0}, {1}, {2}, ... {n} 이 각각 n+1개의 집합을 이루고 있다. 여기에 합집합 연산과, 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산을 수행하려고 한다.
집합을 표현하는 프로그램을 작성하시오.

입력
첫째 줄에 n(1 ≤ n ≤ 1,000,000), m(1 ≤ m ≤ 100,000)이 주어진다. m은 입력으로 주어지는 연산의 개수이다. 다음 m개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는 a가 포함되어 있는 집합과, b가 포함되어 있는 집합을 합친다는 의미이다. 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산은 1 a b의 형태로 입력이 주어진다. 이는 a와 b가 같은 집합에 포함되어 있는지를 확인하는 연산이다. a와 b는 n 이하의 자연수 또는 0이며 같을 수도 있다.

출력
1로 시작하는 입력에 대해서 한 줄에 하나씩 YES/NO로 결과를 출력한다. (yes/no 를 출력해도 된다)

  • 아이디어
  1. index 번호를 가진 리스트 arr을 만든다.
  2. find 함수로 해당 노드의 루트 노드를 찾아낸다
    • 여기서 부모 노드를 갱신해주는 과정이 필요하다 (시간 단축을 위해!!!!)
  3. union 함수로 루트 노드가 더 작은 집합에 큰 집합을 합치는 과정을 추가한다
  • 코드
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

n, m = map(int, input().split())

arr = [ i for i in range(n+1)]

def find(idx):
    global arr
    value = arr[idx]

    if value == idx:
         return idx
    
    #부모 노드 갱신
    arr[idx] = find(value)
    return arr[idx]

def union(a, b):
    global arr
    #루트 노드를 찾는다
    x = find(a)
    y = find(b)

    if x == y:
        return
    
    if x < y:
        arr[y] = x
    else:
        arr[x] = y

for _ in range(m):
    a, b, c = map(int, input().split())
    if a == 0:
        union(b, c)
    if a == 1:  
        if find(b) == find(c):
            print("YES")
        else:
            print("NO")
    #print(arr)
  • 결과

참고자료 - https://brownbears.tistory.com/460

0개의 댓글