최소 신장 트리(MST), 크루스칼 알고리즘과 프림 알고리즘

송히·2024년 6월 14일
0

개발 공부 🐨

목록 보기
14/15
post-thumbnail

그래프의 모든 노드들을 확인하면서 비용은 최소화하고 싶을 때는 최소 신장 트리를 사용하면 됩니다.

여기서 말하는 최소 신장 트리란 무엇인지와, 이 트리를 찾는데 사용되는 알고리즘인 크루스칼 알고리즘, 프림 알고리즘까지 정리했습니다😊


1. 최소 신장 트리 (MST)

최소 신장 트리(MST: Minimum Spanning Tree, 최소 스패닝 트리)
: 가중치가 부여된 연결 그래프에서 모든 노드를 연결하는 간선들의 부분 집합인데, 이때 그 가중치의 합이 최소가 되는 트리

  • 하나의 연결 그래프 (모든 노드가 연결된 1개의 트리 구조)
  • 노드n개일 때, 간선n - 1개
  • 사이클이 없어야함 (순환하는 노드들이 있으면 안 됨)

ex) 모든 도시를 연결하는 최소 비용의 도로 설계, 최소한의 비용으로 모든 지역의 전력망 구축, 통신 네트워크 최적화



2. MST 찾는 2가지 알고리즘

2-1. 크루스칼 알고리즘 (Kruskal's Algorithm)

  • 간선 중심 접근법
  • 모든 간선을 오름차순으로 정렬한 뒤, 사이클이 형성되지 않도록 간선을 선택해나가며 MST를 구성
    -> 간선 선택 최적화 필요
  • 간선 수가 적은 그래프에서 유리 (Sparse Graph)
    => 간선 <= 10 * 정점일 경우 추천

크루스칼 알고리즘 진행 방식
1. 모든 간선을 가중치 기준으로 오름차순 정렬
2. 작은 가중치의 간선부터 하나씩 선택
⠀ ⠀ if) 간선을 추가했을 때 사이클이 형성되면 제외
3. 총 n-1개의 간선이 선택되면 종료됨

  • 크루스칼 알고리즘 예시
    # edges = [노드1, 노드2, 가중치]
    def kruskal(n, edges): 
        edges.sort(key=lambda x: x[2]) # 간선을 가중치 기준으로 정렬
        parent = [i for i in range(n)]
        
        # find 함수: 특정 노드가 속한 집합(루트 노드)을 찾음
        def find(x):
            if parent[x] != x:
                parent[x] = find(parent[x])  # 경로 압축
            return parent[x]
        
        # union 함수: 두 집합을 병합
        def union(a, b):
            root_a = find(a)
            root_b = find(b)
            if root_a != root_b:
                parent[root_b] = root_a
        
        mst_cost = 0
        
        for a, b, cost in edges: # 정렬된 간선을 순서대로 확인
            if find(a) != find(b): # 두 노드가 다른 집합에 속하면 간선을 선택
                union(a, b)
                mst_cost += cost
        return mst_cost

Union & Find 알고리즘

  • 서로소 집합(Disjoint Set) 알고리즘:
    1. 두 노드가 같은 집합인지 확인, 집합 병합 할 때 사용
      • find: 특정 노드가 속한 집합(루트 노드)을 찾는 함수
        -> 경로 압축으로 속한 집합을 효율적으로 찾음
        -> 현재 노드의 부모를 직접 루트 노드로 설정하는 것
      • union: 두 집합병합하는 함수
        -> 각 집합의 높이(랭크)를 기준으로 병합 (루트 노드를 바꾸는 것)
        -> 낮은 트리를 높은 트리에 붙임 (같으면 아무데나 붙이고 랭크 += 1)
    2. 크루스칼 알고리즘에서 사이클이 발생 여부 확인할 때도 사용 (두 노드가 같은 집합에 속하면 사이클 발생)

  • Union & Find 알고리즘 예시
    # 루트 노드를 찾는 함수
    def find(parent, x):
        if parent[x] != x: parent[x] = find(parent, parent[x])  # 경로 압축
        return parent[x]# 두 집합을 병합하는 함수
    def union(parent, rank, a, b):
        root_a = find(parent, a)
        root_b = find(parent, b)
        if root_a != root_b: # 루트 노드가 같지 않을 경우 랭크 길이가 긴 쪽에 붙임 (루트 노드 변경)
            if rank[root_a] > rank[root_b]: parent[root_b] = root_a
            elif rank[root_a] < rank[root_b]: parent[root_a] = root_b
            else: # 같으면 아무데나 붙이고 랭크 길이 += 1
                parent[root_b] = root_a 
                rank[root_a] += 1

2-2. 프림 알고리즘 (Prim's Algorithm)

  • 노드 중심 접근법
  • 시작 노드에서 출발해 가장 가중치가 작은 간선을 추가해가면서 트리를 확장
    -> 노드 연결 최적화 필요
  • 간선 수가 많은 그래프에서 유리(Dense Graph)
    => 간선 > 10 * 정점일 경우 추천

프림 알고리즘 진행 방식
1. 임의의 노드를 시작점으로 선택
2. 해당 노드에서 연결 가능한 간선 중 가중치가 최소인 간선 선택
3. 새로 추가된 노드에서 다시 연결 가능한 최소 간선 탐색
4. 모든 노드가 포함될 때까지 반복

  • 프림 알고리즘 예시

    import heapq
    
    def prim(n, edges):
        # 그래프를 인접 리스트 형태로 변환
        graph = {i: [] for i in range(n)}
        for a, b, cost in edges:
            graph[a].append((cost, b))
            graph[b].append((cost, a))
        
        visited = [False] * n  # 방문 여부를 저장
        pq = [(0, 0)]  # 우선순위 큐: (가중치, 노드) 순서
        mst_cost = 0  # 최소 스패닝 트리의 총 비용
    
        while pq:
            cost, node = heapq.heappop(pq)  # 가중치가 가장 작은 간선을 꺼냄
            if visited[node]:
                continue  # 이미 방문한 노드라면 스킵
            visited[node] = True
            mst_cost += cost  # 간선 비용을 누적
    
            for next_cost, next_node in graph[node]:
                if not visited[next_node]:
                    heapq.heappush(pq, (next_cost, next_node))  # 인접 노드를 우선순위 큐에 추가
    
        return mst_cost

우선순위 큐

  • 우선순위를 기준으로 값을 정렬하고, 가장 우선순위가 높은 값을 빠르게 꺼낼 수 있는 자료구조
    -> 파이썬에서는 heapq으로 최소 힙 형태의 우선순위 큐 사용 가능
  • 프림 알고리즘에서 간선 비용을 기준으로 간선을 정렬 후, 최소 가중치 간선을 선택할 때도 사용

  • 우선순위 큐 예시
    import heapq
    ⠀ 
    pq = []  # 우선순위 큐
    heapq.heappush(pq, (3, 'C'))  # (우선순위, 값) 순로 넣어야함
    heapq.heappush(pq, (1, 'A'))
    heapq.heappush(pq, (2, 'B'))while pq:
        print(heapq.heappop(pq))  # 우선순위가 낮은 순서대로 출력

2-3. 크루스칼프림 알고리즘의 비교

  • 공통점:
    • 최소 신장 트리(MST)를 찾는 데 사용됨
    • 그리디 알고리즘 기반
  • 차이점:

    특징크루스칼 알고리즘프림 알고리즘
    접근 방식간선 중심 (최소 가중치 간선 계속 추가)노드 중심 (최소 가중치 간선 선택하며 트리 확장 )
    초기 상태일단 모든 간선을 정렬 후 처리임의의 시작 노드에서 시작
    유리한 그래프적은 간선 수(Sparse Graph)많은 간선 수(Dense Graph)
    보조 도구Union-Find 알고리즘(사이클 확인)우선순위 큐(간선 선택 최적화)


3. MST 관련 문제 추천 및 풀이

문제 리스트

  • 백준: 1197번(최소 스패닝 트리), 1647번(도시 분할 계획), 1922번(네트워크 연결), 21924번(도시 건설), 1774번(우주신과의 교감)

  • 프로그래머스: 섬 연결하기, 전력망을 둘로 나누기

🔍 1197번: 최소 스패닝 트리

📖 풀이 코드

import sys

v, e = list(map(int, sys.stdin.readline().strip().split()))
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(e)]
edges.sort(key = lambda x: x[2])
parent = list(range(v + 1))
rank = [0] * (v + 1)

def find(x):
    if x != parent[x]: parent[x] = find(parent[x])
    return parent[x]

def union(rootA, rootB):
    if rank[rootA] > rank[rootB]: parent[rootB] = rootA
    elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
    else:
        parent[rootB] = rootA
        rank[rootA] += 1

result = 0

for (a, b, cost) in edges:
    rootA = find(a)
    rootB = find(b)
    if rootA != rootB:
        union(rootA, rootB)
        result += cost

print(result)

📢 풀이 설명
기본적인 MST 문제다. 문제에서 주어진 간선 수가 최대 100,000개이기 때문에 크루스칼 알고리즘으로도 무리없다고 생각했다.

기본적인 MST를 만드는 문제라서 위에 나온 크루스칼 알고리즘 예시코드 그대로 풀어도 무리 없었다.



🔍 1647번: 도시 분할 계획

📖 풀이 코드

import sys

n, m = list(map(int, sys.stdin.readline().strip().split()))
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
edges.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)

def find(x):
    if x != parent[x]: parent[x] = find(parent[x])
    return parent[x]

def union(rootA, rootB):
    if rank[rootA] > rank[rootB]: parent[rootB] = rootA
    elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
    else:
        parent[rootB] = rootA
        rank[rootA] += 1

total = 0
maxCost = 0

for (a, b, cost) in edges:
    rootA = find(a)
    rootB = find(b)

    if rootA != rootB:
        union(rootA, rootB)
        total += cost
        maxCost = cost

print(total - maxCost)

📢 풀이 설명
문제가 길지만 요약하면 아래와 같다.

  1. MST 만들어서
  2. 가장 비용이 큰 간선 끊기

정점 100,000개 & 간선 1,000,000개여서 크루스칼 알고리즘으로 풀었다. MST 만들면서 최대 비용을 계속 갱신한 후, 마지막에 total에서 그 값을 빼줬다.



🔍 1922번: 네트워크 연결

📖 풀이 코드

# 크루스칼 알고리즘 버전
import sys

n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
edges.sort(key = lambda x: x[2])

parent = list(range(n + 1))
rank = [0] * (n + 1)

def find(x):
    if x != parent[x]: parent[x] = find(parent[x])
    return parent[x]

def union(rootA, rootB):
    if rank[rootA] > rank[rootB]: parent[rootB] = rootA
    elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
    else:
        parent[rootB] = rootA
        rank[rootA] += 1

total = 0

for (a, b, cost) in edges:
    rootA = find(a)
    rootB = find(b)

    if rootA != rootB:
        union(rootA, rootB)
        total += cost

print(total)

# 프림 알고리즘 버전
import sys
import heapq

n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
graph = [[] for _ in range(n)]
for _ in range(m):
    a, b, cost = map(int, sys.stdin.readline().strip().split())
    graph[a - 1].append((cost, b - 1))
    graph[b - 1].append((cost, a - 1))

pq = [(0, 0)]
visited = [False] * n
total = 0

while pq:
    cost, node = heapq.heappop(pq)
    if visited[node]: continue

    visited[node] = True
    total += cost

    for (nCost, nNode) in graph[node]:
        if not visited[nNode]:
            heapq.heappush(pq, (nCost, nNode))

print(total)

📢 풀이 설명
기본적인 MST 문제다. 이 문제가 MST임을 알 수 있는 키워드는 모든 컴퓨터가 연결이 되어 있어야 한다., 모든 컴퓨터를 연결하는데 필요한 최소비용이었다.
크루스칼 알고리즘과 프림 알고리즘 2가지 버전으로 풀어봤다. 밀집 간선이 아니라서 크루스칼로 푼 버전이 더 효율적이다.



🔍 21924번: 도시 건설

📖 풀이 코드

import sys

n, m = list(map(int, sys.stdin.readline().strip().split()))
graph = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
graph.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)

def find(x):
    if x != parent[x]: parent[x] = find(parent[x])
    return parent[x]

def union(rootA, rootB):
    if rank[rootA] > rank[rootB]: parent[rootB] = rootA
    elif rank[rootB] > rank[rootA]: parent[rootA] = rootB
    else:
        parent[rootB] = rootA
        rank[rootA] += 1

total = 0
save = 0
for (a, b, cost) in graph:
    total += cost
    rootA = find(a)
    rootB = find(b)

    if rootA != rootB:
        union(rootA, rootB)
        save += cost

curParent = parent[1]
for i in range(2, n + 1):
    if parent[i] != curParent:
        rootCur = find(curParent)
        rootI = find(parent[i])

        if rootI != rootCur:
            print(-1)
            sys.exit()

print(total - save)

📢 풀이 설명
이 문제는 특히 더 지문을 꼼꼼히 읽어야한다. 기본적인 MST 문제인데 2가지 함정이 있기 때문... (제가 2번 틀려서 그런 건 아닐 겁니다.,, 아마두..)

  1. 최소 비용이 아닌 절약된 값을 반환해야함
  2. 모든 정점들이 연결되어있지 않을 수도 있음

이 부분만 주의해서 풀면 금방 풀 수 있다. 그리고 정점 * 10 >= 간선이어서 크루스칼 알고리즘을 이용했다.



🔍 1774번: 우주신과의 교감

📖 풀이 코드

📢 풀이 설명



🔍 섬 연결하기

📖 풀이 보러 가기


🔍 전력망을 둘로 나누기

📖 풀이 보러 가기

profile
데브코스 프론트엔드 5기

0개의 댓글

관련 채용 정보