그래프의 모든 노드들을 확인하면서 비용은 최소화하고 싶을 때는 최소 신장 트리
를 사용하면 됩니다.
여기서 말하는 최소 신장 트리란 무엇인지와, 이 트리를 찾는데 사용되는 알고리즘인 크루스칼 알고리즘
, 프림 알고리즘
까지 정리했습니다😊
최소 신장 트리
(MST: Minimum Spanning Tree, 최소 스패닝 트리)
: 가중치가 부여된 연결 그래프에서 모든 노드를 연결하는 간선들의 부분 집합인데, 이때 그 가중치의 합이 최소가 되는 트리
n개
일 때, 간선은 n - 1개
ex) 모든 도시를 연결하는 최소 비용의 도로 설계, 최소한의 비용으로 모든 지역의 전력망 구축, 통신 네트워크 최적화
간선 <= 10 * 정점
일 경우 추천
크루스칼 알고리즘
진행 방식
1. 모든 간선을 가중치 기준으로 오름차순 정렬
2. 작은 가중치의 간선부터 하나씩 선택
⠀ ⠀ if) 간선을 추가했을 때 사이클이 형성되면 제외
3. 총n-1개
의 간선이 선택되면 종료됨
# edges = [노드1, 노드2, 가중치]
def kruskal(n, edges):
edges.sort(key=lambda x: x[2]) # 간선을 가중치 기준으로 정렬
parent = [i for i in range(n)]
# find 함수: 특정 노드가 속한 집합(루트 노드)을 찾음
def find(x):
if parent[x] != x:
parent[x] = find(parent[x]) # 경로 압축
return parent[x]
# union 함수: 두 집합을 병합
def union(a, b):
root_a = find(a)
root_b = find(b)
if root_a != root_b:
parent[root_b] = root_a
mst_cost = 0
for a, b, cost in edges: # 정렬된 간선을 순서대로 확인
if find(a) != find(b): # 두 노드가 다른 집합에 속하면 간선을 선택
union(a, b)
mst_cost += cost
return mst_cost
Union & Find 알고리즘
서로소 집합(Disjoint Set)
알고리즘:
- 두 노드가 같은 집합인지 확인, 집합 병합 할 때 사용
find
: 특정 노드가 속한 집합(루트 노드)을 찾는 함수
-> 경로 압축으로 속한 집합을 효율적으로 찾음
-> 현재 노드의 부모를 직접 루트 노드로 설정하는 것union
: 두 집합을 병합하는 함수
-> 각 집합의 높이(랭크)를 기준으로 병합 (루트 노드를 바꾸는 것)
-> 낮은 트리를 높은 트리에 붙임 (같으면 아무데나 붙이고 랭크 += 1)- 크루스칼 알고리즘에서 사이클이 발생 여부 확인할 때도 사용 (두 노드가 같은 집합에 속하면 사이클 발생)
- Union & Find 알고리즘 예시
# 루트 노드를 찾는 함수 def find(parent, x): if parent[x] != x: parent[x] = find(parent, parent[x]) # 경로 압축 return parent[x] ⠀ # 두 집합을 병합하는 함수 def union(parent, rank, a, b): root_a = find(parent, a) root_b = find(parent, b) if root_a != root_b: # 루트 노드가 같지 않을 경우 랭크 길이가 긴 쪽에 붙임 (루트 노드 변경) if rank[root_a] > rank[root_b]: parent[root_b] = root_a elif rank[root_a] < rank[root_b]: parent[root_a] = root_b else: # 같으면 아무데나 붙이고 랭크 길이 += 1 parent[root_b] = root_a rank[root_a] += 1
간선 > 10 * 정점
일 경우 추천
프림 알고리즘
진행 방식
1. 임의의 노드를 시작점으로 선택
2. 해당 노드에서 연결 가능한 간선 중 가중치가 최소인 간선 선택
3. 새로 추가된 노드에서 다시 연결 가능한 최소 간선 탐색
4. 모든 노드가 포함될 때까지 반복
프림 알고리즘 예시
import heapq
def prim(n, edges):
# 그래프를 인접 리스트 형태로 변환
graph = {i: [] for i in range(n)}
for a, b, cost in edges:
graph[a].append((cost, b))
graph[b].append((cost, a))
visited = [False] * n # 방문 여부를 저장
pq = [(0, 0)] # 우선순위 큐: (가중치, 노드) 순서
mst_cost = 0 # 최소 스패닝 트리의 총 비용
while pq:
cost, node = heapq.heappop(pq) # 가중치가 가장 작은 간선을 꺼냄
if visited[node]:
continue # 이미 방문한 노드라면 스킵
visited[node] = True
mst_cost += cost # 간선 비용을 누적
for next_cost, next_node in graph[node]:
if not visited[next_node]:
heapq.heappush(pq, (next_cost, next_node)) # 인접 노드를 우선순위 큐에 추가
return mst_cost
우선순위 큐
- 우선순위를 기준으로 값을 정렬하고, 가장 우선순위가 높은 값을 빠르게 꺼낼 수 있는 자료구조
-> 파이썬에서는heapq
으로 최소 힙 형태의 우선순위 큐 사용 가능- 프림 알고리즘에서 간선 비용을 기준으로 간선을 정렬 후, 최소 가중치 간선을 선택할 때도 사용
- 우선순위 큐 예시
import heapq ⠀ pq = [] # 우선순위 큐 heapq.heappush(pq, (3, 'C')) # (우선순위, 값) 순로 넣어야함 heapq.heappush(pq, (1, 'A')) heapq.heappush(pq, (2, 'B')) ⠀ while pq: print(heapq.heappop(pq)) # 우선순위가 낮은 순서대로 출력
차이점:
특징 | 크루스칼 알고리즘 | 프림 알고리즘 |
---|---|---|
접근 방식 | 간선 중심 (최소 가중치 간선 계속 추가) | 노드 중심 (최소 가중치 간선 선택하며 트리 확장 ) |
초기 상태 | 일단 모든 간선을 정렬 후 처리 | 임의의 시작 노드에서 시작 |
유리한 그래프 | 적은 간선 수(Sparse Graph) | 많은 간선 수(Dense Graph) |
보조 도구 | Union-Find 알고리즘(사이클 확인) | 우선순위 큐(간선 선택 최적화) |
문제 리스트
- 백준: 1197번(최소 스패닝 트리), 1647번(도시 분할 계획), 1922번(네트워크 연결), 21924번(도시 건설), 1774번(우주신과의 교감)
- 프로그래머스: 섬 연결하기, 전력망을 둘로 나누기
import sys
v, e = list(map(int, sys.stdin.readline().strip().split()))
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(e)]
edges.sort(key = lambda x: x[2])
parent = list(range(v + 1))
rank = [0] * (v + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
result = 0
for (a, b, cost) in edges:
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
result += cost
print(result)
📢 풀이 설명
기본적인 MST 문제다. 문제에서 주어진 간선 수가 최대 100,000개이기 때문에 크루스칼 알고리즘
으로도 무리없다고 생각했다.
기본적인 MST를 만드는 문제라서 위에 나온 크루스칼 알고리즘 예시코드 그대로 풀어도 무리 없었다.
import sys
n, m = list(map(int, sys.stdin.readline().strip().split()))
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
edges.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
total = 0
maxCost = 0
for (a, b, cost) in edges:
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
total += cost
maxCost = cost
print(total - maxCost)
📢 풀이 설명
문제가 길지만 요약하면 아래와 같다.
정점 100,000개 & 간선 1,000,000개여서 크루스칼 알고리즘으로 풀었다. MST 만들면서 최대 비용을 계속 갱신한 후, 마지막에 total에서 그 값을 빼줬다.
# 크루스칼 알고리즘 버전
import sys
n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
edges = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
edges.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootA] < rank[rootB]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
total = 0
for (a, b, cost) in edges:
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
total += cost
print(total)
# 프림 알고리즘 버전
import sys
import heapq
n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
graph = [[] for _ in range(n)]
for _ in range(m):
a, b, cost = map(int, sys.stdin.readline().strip().split())
graph[a - 1].append((cost, b - 1))
graph[b - 1].append((cost, a - 1))
pq = [(0, 0)]
visited = [False] * n
total = 0
while pq:
cost, node = heapq.heappop(pq)
if visited[node]: continue
visited[node] = True
total += cost
for (nCost, nNode) in graph[node]:
if not visited[nNode]:
heapq.heappush(pq, (nCost, nNode))
print(total)
📢 풀이 설명
기본적인 MST 문제다. 이 문제가 MST임을 알 수 있는 키워드는 모든 컴퓨터가 연결이 되어 있어야 한다.
, 모든 컴퓨터를 연결하는데 필요한 최소비용
이었다.
크루스칼 알고리즘과 프림 알고리즘 2가지 버전으로 풀어봤다. 밀집 간선이 아니라서 크루스칼로 푼 버전이 더 효율적이다.
import sys
n, m = list(map(int, sys.stdin.readline().strip().split()))
graph = [list(map(int, sys.stdin.readline().strip().split())) for _ in range(m)]
graph.sort(key = lambda x: x[2])
parent = list(range(n + 1))
rank = [0] * (n + 1)
def find(x):
if x != parent[x]: parent[x] = find(parent[x])
return parent[x]
def union(rootA, rootB):
if rank[rootA] > rank[rootB]: parent[rootB] = rootA
elif rank[rootB] > rank[rootA]: parent[rootA] = rootB
else:
parent[rootB] = rootA
rank[rootA] += 1
total = 0
save = 0
for (a, b, cost) in graph:
total += cost
rootA = find(a)
rootB = find(b)
if rootA != rootB:
union(rootA, rootB)
save += cost
curParent = parent[1]
for i in range(2, n + 1):
if parent[i] != curParent:
rootCur = find(curParent)
rootI = find(parent[i])
if rootI != rootCur:
print(-1)
sys.exit()
print(total - save)
📢 풀이 설명
이 문제는 특히 더 지문을 꼼꼼히 읽어야한다. 기본적인 MST 문제인데 2가지 함정이 있기 때문... (제가 2번 틀려서 그런 건 아닐 겁니다.,, 아마두..)
이 부분만 주의해서 풀면 금방 풀 수 있다. 그리고 정점 * 10 >= 간선
이어서 크루스칼 알고리즘을 이용했다.
📢 풀이 설명