MST(Minimum Spanning Tree)

이영구·2022년 9월 10일
0

Algorithm

목록 보기
13/19
  • 스패닝 트리
    : 그래프에서 일부 간선을 선택해서 만든 트리
    : 스패닝 트리는 트리의 특수한 형태이므로 모든 정점들이 연결되어 있어야 하고 사이클을 포함해서는 안됨 ( 트리의 정의가 사이클 없이 모든 정점이 연결되어 있는 그래프)
    : 스패닝 트리는 그래프에 있는 n개의 정점을 n-1개의 간선으로 연결

  • 최소 스패닝 트리
    : 그래프에서 tree를 만들 때, 스패닝 트리 중에 선택한 간선의 가중치의 합이 최소인 트리
    : n개의 정점을 가지는 그래프에 대해 반드시 n-1개의 간선만을 사용해야 함
    : 사이클이 포함되어서는 안됨
    : Prim 알고리즘과 Kruskal 알고리즘이 대표적임 (가장 중요한 것은 사이클을 만들지 않는다는 것)

  • Prim, Kruskal 알고리즘 비교
    : Prim 알고리즘 -> 시작점에서 끝점까지 단계적으로 확장, 정점 선택을 기준으로 최적해 찾기
    : Kruskal 알고리즘 -> 가중치를 간선에 할당, 간선 선택을 기준으로 최적해 찾기

  • 프림 알고리즘
    : 1. 그래프에서 임의의 정점 선택
    : 2. 선택한 정점(기존의 선택한 모든 정점)과 선택하지 않은 정점(다음 정점)을 연결하는 간선 중 최소값을 고름.
    : 3. 선택한 간선을 MST에 추가하고 다음 정점으로 v를 추가
    : 4. 모든 정점을 선택하지 않았다면, 2번 단계로 돌아감
    : (시간복잡도) (V-1개의 정점을 선택시 모든 간선 E 중에서 선택해야 하므로 시간 복잡도는 O(V*E), E의 최대값을 V2 이므로, O(V3 이라고 볼 수 있다.
    다만, 선택과 선택하지 않은 간선의 선택에서 우선순위 큐를 이용할 경우 O(ElogE)로 줄 일 수 있다.
    : 최소 스패닝 트리이므로, n-1개의 간선만 이용해야 하며, 한번 방문한 정점을 다시 방문할 일이 없기 때문에 이를 확인하기 위한 배열이 필요하다.

(예시) 백준 1922 네트워크 연결 문제

import sys
import heapq

rdln = sys.stdin.readline

def prim(v):
    q = []
    mst[v] = 1 # 방문했음을 표시하기 위한 배열, 1은 방문을 의미
    result = 0
    for i in adj[v]:
        heapq.heappush(q, i)
    
    while q:
        c, v = heapq.heappop(q)
        if not mst[v]:
            mst[v] = 1
            result += c
            for j in adj[v]:
                heapq.heappush(q, j)
        if sum(mst) == n:
            return result

n = int(rdln())
m = int(rdln())
adj = [[] for _ in range(n+1)]
mst = [0] * (n+1)
for _ in range(m):
    a, b, c = map(int, rdln().split())
    adj[a].append([c, b])
    adj[b].append([c, a])

print(prim(1))
  • Kruskal 알고리즘
    : Union - FInd를 이용하는 알고리즘
    : 사이클이 생기지 않게 MST에 추가하는 방식
    : 가중치를 모두 queue에 포함시켜 놓고 이 union - find를 이용해서 가중치를 이용할지 판단
    : 가중치를 추가했을 때, 사이클이 형성되면( 노드 2개가 이미 같은 그룹에 속할 경우) skip

(예시) 백준 1197 최소 스패닝 트리 문제

def find(i):
    if parent[i] != i:
        parent[i] = find(parent[i])
    return parent[i]

def union(irep, jrep):

    if rank[irep] <= rank[jrep]:
        parent[irep] = jrep
        if rank[irep] ==  rank[jrep]:
            rank[jrep] += 1
    else:
        parent[jrep] = irep

def Kruskal(graph):
    mst_cost = 0

    for i, j, wt in graph:
        irep, jrep = find(i), find(j)
        if irep != jrep:
            union(irep, jrep)
            mst_cost += wt
    
    return mst_cost

v, e = map(int, rdln().split())
graph = []
parent = [i for i in range(v+1)]
rank = [0] * (v+1)

for _ in range(e):
    a, b, c = map(int, rdln().split())
    graph.append((a, b, c))

graph.sort(key = lambda x:x[2])

sys.stdout.write(f"{Kruskal(graph)}")
profile
In to the code!

0개의 댓글