최소 비용으로 그래프의 모든 정점을 연결해 봅시다.
Java / Python
신장 트리 중에서도 가중치 합이 최소인 최소 신장 트리(MST)를 배우는 문제
이번 문제는 그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하는 문제이다.
크루스칼 알고리즘 기본문제라고 볼 수 있다.
크루스칼 알고리즘은 가장 적은 비용으로 모든 노드를 연결하기 위해 사용하는 알고리즘이다. 크루스칼의 기본은 간선을 중심으로 생각하는 것이며, 간선의 가중치가 가장 작은 것을 고르고 후에 싸이클이 생기지 않고 모든 노드를 방문할 수 있도록 고른다.
다음 과정을 간선의 개수만큼 반복한다.
1. 가중치가 가장 작은 간선 하나 poll
2. 시작 노드와 끝 노드의 최상위 노드 find (최상위 노드가 없는 경우, 자기 자신)
3. 최상위 노드가 다른 경우 union을 통해 그 간선 선택, result에 가중치를 더해준다.
parent를 자기 자신으로 초기화하고, 부모를 찾는 함수 find와 합집합 연산을 해, 같은 부모를 가지도록 하는 union함수를 이용한다.
import java.io.*;
import java.util.*;
public class Main{
static class Edge implements Comparable<Edge> {
int s, e, cost;
Edge(int s, int e, int cost) {
this.s = s;
this.e = e;
this.cost = cost;
}
@Override
public int compareTo(Edge arg0) {
// Comparable을 통해 정렬 우선순위 (cost 기준)
return arg0.cost >= this.cost ? -1 : 1;
}
}
static int[] parent;
static int V, E, result;
static PriorityQueue<Edge> pque = new PriorityQueue<Edge>();
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st = new StringTokenizer(br.readLine());
V = Integer.parseInt(st.nextToken());
E = Integer.parseInt(st.nextToken());
parent = new int[V + 1];
for (int i = 1; i <= V; i++) {
parent[i] = i;
}
for (int i = 0; i < E; i++) {
st = new StringTokenizer(br.readLine());
int start = Integer.parseInt(st.nextToken());
int end = Integer.parseInt(st.nextToken());
int cost = Integer.parseInt(st.nextToken());
pque.add(new Edge(start, end, cost));
}
for (int i = 0; i < E; i++) {
Edge tmp = pque.poll();
int a = find(tmp.s);
int b = find(tmp.e);
if (a != b) {
union(a, b);
result += tmp.cost;
}
}
bw.write(result + "\n");
bw.flush();
bw.close();
br.close();
}
// x의 부모 찾기
public static int find(int x) {
if (x == parent[x])
return x;
return parent[x] = find(parent[x]);
}
// y 부모를 x 부모로 치환하기 (x > y 일 경우 반대)
public static void union(int x, int y) {
x = find(x);
y = find(y);
if (x != y) {
if (x < y) {
parent[y] = x;
} else {
parent[x] = y;
}
}
}
}
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
def find(x):
if x == parent[x]:
return x
parent[x] = find(parent[x]) # 부모 테이블 갱신
return parent[x]
def union(x, y):
x = find(x)
y = find(y)
if x == y: # 동일한 집합일 경우
return
if x < y:
parent[y] = x
else:
parent[x] = y
V, E = map(int, input().split())
edge = []
for _ in range(E):
s, e, cost = map(int, input().split())
edge.append((s, e, cost))
edge.sort(key=lambda x : x[2])
parent = [i for i in range(V+1)]
result = 0
for s, e, c in edge:
if find(s) != find(e):
union(s, e)
result += c
print(result)