[Python] Minimum Spanning Tree

이재원·2023년 9월 30일

Algorithm

목록 보기
9/29

Kruskal's Algorithm

  • 사이클을 형성하지 않으면서 비용이 최소인 순서대로 N-1개의 간선을 선택(N은 노드의 개수)
# Kruskal's Algorithm Baseline Code

import sys

# Union Find
# input x의 부모를 찾는 함수
def find_parent(parent, x):

    if parent[x] != x:

        parent[x] = find_parent(parent, parent[x])
    
    return parent[x]

# input a, b를 합치는 함수
def union_parent(parent, a, b):

    a = find_parent(parent, a)
    b = find_parent(parent, b)

    # 작은 번호가 부모, 큰 번호가 자식
    if a < b:

        parent[b] = a
    
    else:

        parent[a] = b

# Kruskal's Algorithm : 비용이 최소이면서 사이클을 형성하지 않는 간선을 선택하는 알고리즘
def kruskal():

    # 간선을 오름차순 정렬
    edges.sort()

    # MST 간선의 총합
    total = 0

    # 작은 간선부터 차례대로 확인합니다.
    for edge in edges:

        cost, a, b = edge

        # 두 노드의 부모가 같지 않을 때만, 즉 사이클을 형성하지 않을 때 간선을 선택합니다.
        if find_parent(parent, a) != find_parent(parent, b):

            # MST에 반영
            union_parent(parent, a, b)

            # 총합에 해당 간선의 cost를 반영
            total += cost
    
    return total

# 노드의 개수와 간선(union)의 개수 입력받기
v, e = map(int, sys.stdin.readline().split())
parent = [0] * (v+1)

for i in range(1, v+1):

    parent[i] = i

edges = []

for _ in range(e):

    a, b, cost = map(int, sys.stdin.readline().split())

    edges.append((cost, a, b))

# kruskal()
print(kruskal())

Prim's Algorithm

  • 임의의 정점에서 시작하고, 사이클을 형성하지 않으면서 비용이 최소인 간선을 매 순간 선택, 총 N-1개의 간선(N은 노드의 개수)
# Prim's Algorithm Baseline Code
import sys
from heapq import heappush, heappop

# Prim's Algorithm
def Prim(start):

    # 전체비용
    total = 0

    # 우선순위 큐
    q = []

    # Minimum Spanning Tree
    mst = []

    # 시작 노드 처리
    mst.append(start)

    # 시작노드의 이웃노드와의 간선을 큐에 넣습니다.
    for neighbor in graph[start]:

        heappush(q, (neighbor[1], neighbor[0]))
    
    # 큐가 빌 때까지 반복합니다.
    while q:

        # 큐에서 POP
        cost, cur = heappop(q)

        # 이미 mst에 있다면 해당 간선 선택시 사이클을 형성하므로 skip
        if cur in mst:

            continue

        else:

            # 총계에 누적
            total += cost

            # mst에 추가
            mst.append(cur)

            # POP한 노드의 이웃노드 탐색
            for neighbor in graph[cur]:

                # 방문하지 않은 노드가 있을 때
                if neighbor[0] not in mst:

                    # 큐에 추가 (비용, 노드번호)
                    heappush(q, (neighbor[1], neighbor[0]))

    # 총 비용 반환
    return total

# 노드의 개수 N, 간선의 개수 M
N, M = map(int, sys.stdin.readline().split())

# 그래프 초기화 (1 ~ N)
graph = [[] for _ in range(N+1)]

# 간선 입력
for _ in range(M):

    a, b, cost = map(int, sys.stdin.readline().split())

    # 그래프에 반영, 양방향 간선
    graph[a].append((b, cost))
    graph[b].append((a, cost))

# Execute, 어떤 노드에서 시작하던지 결과는 똑같음
Prim(1)

0개의 댓글