그래프의 모든 노드들을 확인하면서 비용은 최소화하고 싶을 때는 최소 신장 트리를 사용하면 됩니다.
여기서 말하는 최소 신장 트리란 무엇인지와, 이 트리를 찾는데 사용되는 알고리즘인 크루스칼 알고리즘, 프림 알고리즘까지 정리했습니다😊
최소 신장 트리(MST: Minimum Spanning Tree, 최소 스패닝 트리)
: 가중치가 부여된 연결 그래프에서 모든 노드를 연결하는 간선들의 부분 집합인데, 이때 그 가중치의 합이 최소가 되는 트리
n개일 때, 간선은 n - 1개ex) 모든 도시를 연결하는 최소 비용의 도로 설계, 최소한의 비용으로 모든 지역의 전력망 구축, 통신 네트워크 최적화
간선 <= 10 * 정점일 경우 추천
크루스칼 알고리즘진행 방식
1. 모든 간선을 가중치 기준으로 오름차순 정렬
2. 작은 가중치의 간선부터 하나씩 선택
⠀ ⠀ if) 간선을 추가했을 때 사이클이 형성되면 제외
3. 총n-1개의 간선이 선택되면 종료됨
# edges = [노드1, 노드2, 가중치]
def kruskal(n, edges):
edges.sort(key=lambda x: x[2]) # 간선을 가중치 기준으로 정렬
parent = [i for i in range(n)]
# find 함수: 특정 노드가 속한 집합(루트 노드)을 찾음
def find(x):
if parent[x] != x:
parent[x] = find(parent[x]) # 경로 압축
return parent[x]
# union 함수: 두 집합을 병합
def union(a, b):
root_a = find(a)
root_b = find(b)
if root_a != root_b:
parent[root_b] = root_a
mst_cost = 0
for a, b, cost in edges: # 정렬된 간선을 순서대로 확인
if find(a) != find(b): # 두 노드가 다른 집합에 속하면 간선을 선택
union(a, b)
mst_cost += cost
return mst_costUnion & Find 알고리즘
서로소 집합(Disjoint Set)알고리즘:
- 두 노드가 같은 집합인지 확인, 집합 병합 할 때 사용
find: 특정 노드가 속한 집합(루트 노드)을 찾는 함수
-> 경로 압축으로 속한 집합을 효율적으로 찾음
-> 현재 노드의 부모를 직접 루트 노드로 설정하는 것union: 두 집합을 병합하는 함수
-> 각 집합의 높이(랭크)를 기준으로 병합 (루트 노드를 바꾸는 것)
-> 낮은 트리를 높은 트리에 붙임 (같으면 아무데나 붙이고 랭크 += 1)- 크루스칼 알고리즘에서 사이클이 발생 여부 확인할 때도 사용 (두 노드가 같은 집합에 속하면 사이클 발생)
- Union & Find 알고리즘 예시
# 루트 노드를 찾는 함수 def find(parent, x): if parent[x] != x: parent[x] = find(parent, parent[x]) # 경로 압축 return parent[x] ⠀ # 두 집합을 병합하는 함수 def union(parent, rank, a, b): root_a = find(parent, a) root_b = find(parent, b) if root_a != root_b: # 루트 노드가 같지 않을 경우 랭크 길이가 긴 쪽에 붙임 (루트 노드 변경) if rank[root_a] > rank[root_b]: parent[root_b] = root_a elif rank[root_a] < rank[root_b]: parent[root_a] = root_b else: # 같으면 아무데나 붙이고 랭크 길이 += 1 parent[root_b] = root_a rank[root_a] += 1
간선 > 10 * 정점일 경우 추천
프림 알고리즘진행 방식
1. 임의의 노드를 시작점으로 선택
2. 해당 노드에서 연결 가능한 간선 중 가중치가 최소인 간선 선택
3. 새로 추가된 노드에서 다시 연결 가능한 최소 간선 탐색
4. 모든 노드가 포함될 때까지 반복
프림 알고리즘 예시
import heapq
def prim(n, edges):
# 그래프를 인접 리스트 형태로 변환
graph = {i: [] for i in range(n)}
for a, b, cost in edges:
graph[a].append((cost, b))
graph[b].append((cost, a))
visited = [False] * n # 방문 여부를 저장
pq = [(0, 0)] # 우선순위 큐: (가중치, 노드) 순서
mst_cost = 0 # 최소 스패닝 트리의 총 비용
while pq:
cost, node = heapq.heappop(pq) # 가중치가 가장 작은 간선을 꺼냄
if visited[node]:
continue # 이미 방문한 노드라면 스킵
visited[node] = True
mst_cost += cost # 간선 비용을 누적
for next_cost, next_node in graph[node]:
if not visited[next_node]:
heapq.heappush(pq, (next_cost, next_node)) # 인접 노드를 우선순위 큐에 추가
return mst_cost
우선순위 큐
- 우선순위를 기준으로 값을 정렬하고, 가장 우선순위가 높은 값을 빠르게 꺼낼 수 있는 자료구조
-> 파이썬에서는heapq으로 최소 힙 형태의 우선순위 큐 사용 가능- 프림 알고리즘에서 간선 비용을 기준으로 간선을 정렬 후, 최소 가중치 간선을 선택할 때도 사용
- 우선순위 큐 예시
import heapq ⠀ pq = [] # 우선순위 큐 heapq.heappush(pq, (3, 'C')) # (우선순위, 값) 순로 넣어야함 heapq.heappush(pq, (1, 'A')) heapq.heappush(pq, (2, 'B')) ⠀ while pq: print(heapq.heappop(pq)) # 우선순위가 낮은 순서대로 출력
차이점:
| 특징 | 크루스칼 알고리즘 | 프림 알고리즘 |
|---|---|---|
| 접근 방식 | 간선 중심 (최소 가중치 간선 계속 추가) | 노드 중심 (최소 가중치 간선 선택하며 트리 확장 ) |
| 초기 상태 | 일단 모든 간선을 정렬 후 처리 | 임의의 시작 노드에서 시작 |
| 유리한 그래프 | 적은 간선 수(Sparse Graph) | 많은 간선 수(Dense Graph) |
| 보조 도구 | Union-Find 알고리즘(사이클 확인) | 우선순위 큐(간선 선택 최적화) |
문제 리스트
- 백준: 1197번(최소 스패닝 트리), 1647번(도시 분할 계획), 1922번(네트워크 연결), 21924번(도시 건설), 1774번(우주신과의 교감)
- 프로그래머스: 섬 연결하기, 전력망을 둘로 나누기
import sys
v, e = list(map(int, sys.stdin.readline().strip().split()))
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(e)]
edges.sort(key = lambda x: x[2])
parent = list(range(v + 1))
rank = [0] * (v + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
result = 0
for (a, b, cost) in edges:
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
result += cost
print(result)
📢 풀이 설명
기본적인 MST 문제다. 문제에서 주어진 간선 수가 최대 100,000개이기 때문에 크루스칼 알고리즘으로도 무리없다고 생각했다.
기본적인 MST를 만드는 문제라서 위에 나온 크루스칼 알고리즘 예시코드 그대로 풀어도 무리 없었다.
import sys
n, m = list(map(int, sys.stdin.readline().strip().split()))
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
edges.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
total = 0
maxCost = 0
for (a, b, cost) in edges:
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
total += cost
maxCost = cost
print(total - maxCost)
📢 풀이 설명
문제가 길지만 요약하면 아래와 같다.
정점 100,000개 & 간선 1,000,000개여서 크루스칼 알고리즘으로 풀었다. MST 만들면서 최대 비용을 계속 갱신한 후, 마지막에 total에서 그 값을 빼줬다.
# 크루스칼 알고리즘 버전
import sys
n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
edges.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
total = 0
for (a, b, cost) in edges:
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
total += cost
print(total)
# 프림 알고리즘 버전
import sys
import heapq
n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
graph = [[] for _ in range(n)]
for _ in range(m):
a, b, cost = map(int, sys.stdin.readline().strip().split())
graph[a - 1].append((cost, b - 1))
graph[b - 1].append((cost, a - 1))
pq = [(0, 0)]
visited = [False] * n
total = 0
while pq:
cost, node = heapq.heappop(pq)
if visited[node]: continue
visited[node] = True
total += cost
for (nCost, nNode) in graph[node]:
if not visited[nNode]:
heapq.heappush(pq, (nCost, nNode))
print(total)
📢 풀이 설명
기본적인 MST 문제다. 이 문제가 MST임을 알 수 있는 키워드는 모든 컴퓨터가 연결이 되어 있어야 한다., 모든 컴퓨터를 연결하는데 필요한 최소비용이었다.
크루스칼 알고리즘과 프림 알고리즘 2가지 버전으로 풀어봤다. 밀집 간선이 아니라서 크루스칼로 푼 버전이 더 효율적이다.
import sys
n, m = list(map(int, sys.stdin.readline().strip().split()))
graph = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
graph.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootB] > rank[rootA]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
total = 0
save = 0
for (a, b, cost) in graph:
total += cost
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
save += cost
curParent = parent[1]
for i in range(2, n + 1):
if parent[i] != curParent:
rootCur = find(curParent)
rootI = find(parent[i])
if rootI != rootCur:
print(-1)
sys.exit()
print(total - save)
📢 풀이 설명
이 문제는 특히 더 지문을 꼼꼼히 읽어야한다. 기본적인 MST 문제인데 2가지 함정이 있기 때문... (제가 2번 틀려서 그런 건 아닐 겁니다.,, 아마두..)
이 부분만 주의해서 풀면 금방 풀 수 있다. 그리고 정점 * 10 >= 간선이어서 크루스칼 알고리즘을 이용했다.
📢 풀이 설명