[PS] 서로소 집합, 최소신장 트리

방법이있지·2025년 5월 31일
post-thumbnail

"서로" "소"가 서 있네요
ㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋ

서로소 집합

  • 공통 원소가 없는 두 집합
  • e.g., 집합 {1, 2, 3}{4, 5, 6}은 서로소 관계
  • e.g., 집합 {1, 2}{2, 3, 4}는 서로소 관계가 아님 -> 공통 원소 2를 가짐

서로소 집합 자료구조

  • 서로소 집합으로 나누어진 원소들의 데이터를 처리하기 위한 자료구조
  • 각 원소를 노드로 표현하며, 동일 집합에 속한 노드는 간선으로 연결된 트리로 구성됨
    • 트리의 루트 노드는 해당 집합의 대표 원소 역할을 함
    • 즉 각 노드의 루트를 확인하면, 어느 집합에 속해 있는지 확인할 수 있음
    • 노드의 자식 간 순서는 중요하지 않음
  • 두 종류의 연산 존재
    • 찾기 (Find): 특정 원소의 루트 노드를 알려줌 (즉 어느 집합에 속해 있는지 알려줌)
    • 합집합 (Union): 두 원소가 속한 집합을 하나의 집합으로 합침

찾기 연산

  • 각 노드의 부모를 저장하는 부모 테이블을 이용해, 특정 노드의 루트를 찾음
    • 루트는 부모가 없으므로, 자기 자신을 저장
    • 즉 루트를 찾을 땐, 계속해서 자신의 부모를 찾으면 됨
  • 맨 처음엔 보통 모든 원소가, 자기 자신이 루트인 각각의 집합을 이룸
    • 즉 부모가 자신이 되도록 초기화

합집합 연산

  • 원소 AB의 합집합 연산은, A가 속한 집합과 B가 속한 집합을 합침
  • (1) 노드 A의 루트 노드와, 노드 B의 루트 노드를 찾음
  • (2) 둘 중 한 루트를 다른 루트의 부모로 설정
    • 일반적으로 더 작은 번호 루트를, 더 큰 번호 루트의 부모로 지정

예시

  • 일련의 합집합 연산을 계속 진행하는 예시

구현

입력

N = 6   # 노드의 개수
parent = [i for i in range(N + 1)] # 부모 테이블
  • 맨 처음에는 6개의 노드가 각각 서로 다른 집합에 있음
  • 루트 노드 역시 자기 자신이 루트 노드
  • parent는 편의상 0~6번 인덱스를 갖는 길이 7의 리스트로 만들었으며, 0번 인덱스는 사용하지 않음

찾기 연산

# 부모 테이블 parent를 이용해, x의 루트 노드 찾기
def find(parent, x):
    # 루트를 찾을 때까지 (부모가 자기 자신일 때까지)
    if parent[x] != x:
        # 자신의 부모의 루트를 찾기
        return find(parent, parent[x])
    return x
  • 노드 x의 루트를 찾을 때까지, 즉 x의 부모가 자기 자신일 때까지 재귀 호출
  • 부모가 자기 자신이 아닌 경우, 부모를 따라가며 루트를 찾아 반환

합집합 연산 수행

  • 루트 노드가 같으면, 사이클이 발생한 것- 모든 간선에 대해 위 과정을 반복
# 두 원소가 속한 집합 합치기
def union(parent, a, b):
    # 두 원소의 루트 찾기
    a = find(parent, a)
    b = find(parent, b)

    # 한쪽 루트를 다른쪽 루트의 부모로 설정
    if a < b:
        parent[b] = a
    else:
        parent[a] = b
  • find로 두 원소의 루트를 찾고, 더 큰 노드의 부모를 작은 노드로 지정

연산 결과

union_list = [(1, 4), (2, 3), (2, 4), (5, 6)] # 실행할 합집합 연산 수행
- 루트 노드가 같으면, 사이클이 발생한 것- 모든 간선에 대해 위 과정을 반복

 목록

for a, b in union_list:
    union(parent, a, b)

for i in range(1, N + 1):
    print(f"{i}번 노드: {find(parent, i)}번 노드가 루트인 집합에 속함")
print()

# 1번 노드: 1번 노드가 루트인 집합에 속함
# 2번 노드: 1번 노드가 루트인 집합에 속함
# 3번 노드: 1번 노드가 루트인 집합에 속함
# 4번 노드: 1번 노드가 루트인 집합에 속함
# 5번 노드: 5번 노드가 루트인 집합에 속함
# 6번 노드: 5번 노드가 루트인 집합에 속함

경로 압축

  • 찾기 연산에선, 최악의 경우 찾기 연산 시 모든 노드를 다 확인하게 됨 -> 원소의 수가 NN개일 때 O(N)O(N)

  • 찾기 연산을 최적화하기 위해 경로 압축 이용 가능
  • find 함수만 바꾸어 주면 됨
def find(parent, x):
    # 루트를 찾을 때까지 (부모가 자기 자신일 때까지)
    if parent[x] != x:
        # 자신의 부모의 루트를 찾고, 결과를 저장
        parent[x] = find(parent, parent[x])
    return parent[x]
  • find() 함수로 자신의 부모의 루트를 찾은 뒤, 이 값을 부모 테이블에 갱신
  • 이렇게 구현하면, find() 함수 실행 시 부모 테이블의 값이 루트 노드로 갱신됨
  • 즉 다음부턴 루트 노드를 바로 확인할 수 있음

시간 복잡도

  • 찾기 연산: 경로 압축을 사용하면, 바로 루트노드 확인이 가능하니 거의 O(1)O(1)
    • 단, 경로 압축이 반복적으로 이루어져야 트리가 납작해지므로, 처음 몇 번은 보다 오래 걸림
  • 합집합 연산: 내부에서 두 노드에 대한 찾기 연산을 수행 -> 마찬가지로 거의 O(1)O(1)

사이클 판별 알고리즘

  • 서로소 자료구조를 통해, 무방향 그래프사이클 여부를 사용할 수 있음

원리

  • 각 간선을 하나씩 확인하며, 찾기 연산으로 루트 노드 확인
    • 루트 노드가 다르면 (다른 집합에 속해 있으면), 두 노드에 대해 합집합 연산 수행
    • 루트 노드가 같으면 (같은 집합에 속해 있으면), 사이클이 존재
  • 모든 간선에 대해 위 과정을 반복

예시

구현

  • find, union 함수는 앞선 코드와 동일
  • 그래프를 인접행렬 / 리스트로 구현하진 않아도 되며, 각 간선에 연결된 노드만 (노드 1, 노드 2) 꼴로 edges 리스트에 포함
def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a = find(parent, a)
    b = find(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b


# N: 노드의 수, edges: 간선 리스트
def find_cycle(N, edges):
    # 부모 테이블
    parent = [i for i in range(N + 1)]

    # 간선으로 이어진 두 노드 a, b
    for a, b in edges:
        # 동일 집합에 있으면 사이클 존재
        if find(parent, a) == find(parent, b):
            return True
        # 다른 집합에 있으면 합집합 연산
        union(parent, a, b)
    return False

print(find_cycle(3, [(1, 2), (1, 3), (2, 3)]))  # True
print(find_cycle(4, [(1, 2), (2, 3), (3, 4)])) # False

시간 복잡도

  • 간선의 개수가 EE개일 때, 최대 EE번 찾기, 합집합 연산
    • 찾기, 합집합 연산의 시간 복잡도는 O(1)O(1)에 근접
  • 최종 O(E)O(E)에 근접

최소 신장 트리

신장 트리

  • 그래프에서 모든 노드를 포함하면서, 사이클이 존재하지 않게끔, 일부 간선만 남긴 부분 그래프
  • 왜 "트리"라고 부르냐고요?
    • 트리는 모든 노드가 서로 연결되면서, 사이클이 존재하지 않는 그래프의 일종이기 때문
  • 신장 트리의 간선 개수 = 노드 개수 - 1

최소 신장 트리

  • 가중치의 합이 최소가 되도록, 간선을 남긴 신장 트리

크루스칼 알고리즘

  • 그래프의 최소 신장 트리를 구하는 알고리즘
  • 매번 최저 가중치의 간선을 선택하며 동작
  • 간선의 가중치가 음수여도 사용 가능

과정

  • (1) 간선 데이터를 가중치에 따라 오름차순으로 정렬

  • (2) 가중치가 낮은 간선부터, 현재 간선이 사이클을 발생시키는지 확인
    • 두 노드의 부모를 확인한 후, 동일하면 사이클 발생
    • 발생하지 않는 경우, 신장 트리에 간선을 포함하고, 두 노드에 대해 합집합 연산 수행
  • (3) 모든 간선에 대해 과정을 반복

  • 위 최소신장트리의 가중치 합은 2 + 3 + 4 + 5 + 7 = 21

구현

  • find, union 함수는 앞선 코드와 동일
  • 그래프를 인접행렬 / 리스트로 구현하진 않아도 되며, 각 간선의 (가중치, 노드 1, 노드 2)edges 리스트에 포함
    • list.sort()는 기본적으로 튜플의 첫 원소를 기준으로 정렬하므로, 가중치를 맨 앞에 둠
  • 본 코드는 최소 신장 트리의 모든 간선 가중치 합을 반환
def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a = find(parent, a)
    b = find(parent, b)

    if a < b:
        parent[b] = a
    else:
        parent[a] = b


N = 6
parent = [i for i in range(N + 1)]  # 부모 테이블

# 간선 정보: (가중치, 노드1, 노드2)
edges = [(9, 1, 2), (7, 1, 4), (2, 2, 3), (8, 2, 4), (6, 3, 4), (4, 3, 5), (3, 3, 6), (5, 4, 5)]
edges.sort() # 기본적으로 튜플의 첫 순서 기준으로 정렬됨

# 최소신장트리의 가중치 합
answer = 0

for cost, a, b in edges:
    # 사이클 발생 시 포함 X
    if find(parent, a) == find(parent, b):
        continue
    else:
        answer += cost # answer에 가중치를 더함
        union(parent, a, b)

print(answer) # 21

시간 복잡도

  • EE개의 간선을 정렬할 때, O(ElogE)O(E \log E)
  • 이후 최대 EE번 찾기, 합집합 연산
    • 찾기, 합집합 연산의 시간 복잡도는 O(1)O(1)에 근접
    • 따라서 O(E)O(E)
  • 최종 O(ElogE)O(E \log E)
  • 중복된 간선이 없을 시 EV2E \leq V^2이므로, O(ElogV2)=O(2ElogV)=O(ElogV)O(E \log V^2) = O(2E \log V) = O(E \log V)

문제풀이

1197. 최소 스패닝 트리

백준 / 골드 4 / 1197. 최소 스패닝 트리

  • 가중치가 음수일 수 있다는 조건에서 당황했을 수도 있는데, 음수 있으면 못 쓰는 다익스트라랑 다르게 크루스칼 알고리즘은 음수 가중치가 있어도 사용 가능합니다. 그러니까 배운 대로 푸세용
  • 그리고 찾기 (find) 연산에서 재귀가 많이 발생하게 되므로, sys.setrecursionlimit(10**6) 설정 잊지 않기
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)

def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a = find(parent, a)
    b = find(parent, b)
    
    if a < b:
        parent[b] = a
    else:
        parent[a] = b
        
V, E = map(int, input().split())
parent = [i for i in range(V + 1)]
edges = []


# 가중치, 노드1, 노드2 순
for _ in range(E):
    a, b, cost = map(int, input().split())
    edges.append((cost, a, b))
    
edges.sort()
answer = 0
# 각 가중치에 대해..
for cost, a, b in edges:
    if find(parent, a) == find(parent, b):
        continue
    else:
        answer += cost
        union(parent, a, b)
        
print(answer)
profile
뭔가 만드는 걸 좋아하는 개발자 지망생입니다. 프로야구단 LG 트윈스를 응원하고 있습니다.

2개의 댓글

comment-user-thumbnail
2025년 5월 31일

힘내세요...

1개의 답글