Union-Find 알고리즘

김가영·2021년 2월 21일
0

AlgorithmStudy

목록 보기
9/14
post-thumbnail

gmlwid9405님 블로그 참고
그레이트쪼님 블로그 참고

두 노드가 같은 그래프에 속하는 지 확인하기 위한 알고리즘. 기본적으로 x 가 어떤 집합에 포함되는 지 확인하기 위한 Find 와 집합 x와 y를 합치는 Union 으로 구성된다.

Disjoint Set

서로 중복되지 않는 부분 집합들 로 이루어진 원소들에 대한 정보를 저장하고 조작하는 자료구조
즉, 공통요소가 없는 상호 배타적인 부분집합들로 나눠진 원소들에 대한 자료구조이다. (다른 말로, 서로소 집합)

Union-Find

Disjoint Set 을 표현할 때 사용하는 알고리즘

연산

make-set(x)

  • 초기화
  • x를 유일한 원소로 하는 새로운 집합을 만든다

union(x,y)

  • 합하기
  • x가 속한 집합과 y가 속한 집합을 합친다. 즉, x와 y가 속한 두 집합을 합치는 연산

find(x)

  • 찾기
  • x가 속한 집합의 대표값(루트 노드 값)을 반환한다. 즉, x가 어떤 집합에 속해 있는 지 찾는 연산

Union-Find 알고리즘을 트리 구조로 구현하는 이유

배열로 구현한다면
Array[i] : i번 원소가 속하는 집합의 번호(즉, 루트 노드의 번호)

  • make-set(x) : Array[i] = i 와 같이 각자 다른 집합 번호로 초기화
  • union(x,y) : 배열의 모든 원소를 순회하면서 y의 집합번호를 x의 집합번호로 변경한다. 시간 복잡도 : O(n)
  • find(x) : 한 번에 x 가 속한 집합 번호를 찾는다. 시간 복잡도: O(1)

트리로 구현한다면
같은 집합 = 하나의 트리, 즉 집합 번호 = 루트 노드

  • make-set(x) : 각 노드는 모두 루트 노드이므로 N개의 루트 노드 생성 및 자기 자신으로 초기화
  • union(x,y) : x,y의 루트노드를 찾고, 다르면 y를 x의 자손으로 넣어 두 트리를 합친다. 시간복잡도 O(n) 보다 작다
  • find(x) : 노드의 집합 번호는 루트 노드이므로 루트 노드를 확인하여 같은 집합인지 확인한다. 시간복잡도 : 트리의 높이 (최악: O(N-1))
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n,m = list(map(int,input().split()))

parent = [i for i in range(n+1)]

def union(a,b):
    a = find(a)
    b = find(b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

def find(a):
    if parent[a] == a:
        return a
    else:
        return find(parent[a])
for _ in range(m):
    command, a, b = list(map(int, input().split()))
    # a의 집합과 b의 집합을 합친다
    if command == 0:
        union(a,b)
    # a와 b가 같은 집합인지
    else:
        print('YES') if find(a) == find(b) else print('NO')

기본.

최적화

union-by-rank

각 트리에 대해 높이(rank)를 기억해두고,
union 시 두 트리의 rank 가 다르면 높이가 작은 트리를 높이가 트리의 루트에 붙인다. -> 길이는 높은 쪽의 depth

높이가 같은 트리를 합칠 때에는 한 쪽 트리 높이를 1 증가시키고 다른 트리를 해당 트리에 붙인다. -> rank 에 대힌 집합을 따로 만든다.

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n,m = list(map(int,input().split()))

parent = [i for i in range(n+1)]
rank = [0 for _ in range(n+1)]

def union(a,b):
    a = find(a)
    b = find(b)
    if a == b:
        return

    # 높이가 낮은 트리를 높은 트리 밑에 넣는다
    if rank[a] < rank[b]:
        parent[a] = b
    else:
        parent[b] = a
    if rank[a] == rank[b]:
        rank[a]+=1

def find(a):
    if parent[a] == a:
        return a
    else:
        return find(parent[a])
for _ in range(m):
    command, a, b = list(map(int, input().split()))
    # a의 집합과 b의 집합을 합친다
    if command == 0:
        union(a,b)
    # a와 b가 같은 집합인지
    else:
        print('YES') if find(a) == find(b) else print('NO')

path compression

find를 실행한 노드에서 거쳐간 노드를 root 에 direct 로 연결

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n,m = list(map(int,input().split()))

parent = [i for i in range(n+1)]

def union(a,b):
    a = find(a)
    b = find(b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

def find(a):
    if parent[a] == a:
        return a
    else:
        parent[a] = find(parent[a])
        return parent[a]
for _ in range(m):
    command, a, b = list(map(int, input().split()))
    # a의 집합과 b의 집합을 합친다
    if command == 0:
        union(a,b)
    # a와 b가 같은 집합인지
    else:
        print('YES') if find(a) == find(b) else print('NO')

관련문제
1. BOJ 1717 집합의 표현: https://www.acmicpc.net/problem/1717

profile
개발블로그

0개의 댓글