https://www.acmicpc.net/problem/1197
그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.
최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.
입력
첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.
그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.
출력
첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.
예제 입력 1
3 3
1 2 1
2 3 2
1 3 3
예제 출력 1
3
예전에 아무것도 모르고 이 문제를 풀려고 했다가 메모리 초과가 난 후에,, MST를 제대로 공부하고 다시 풀어봐야겠다! 하고 무려 9개월을 방치했는데 이제는 진짜 공부할 때가 되었다 싶어 MST에 대해 알아보고 풀어보았다.
Union-Find는 서로소 집합(Disjoint Set)을 관리하는 자료구조이다.
서로 다른 집합을 합치거나(Union), 같은 집합에 속해 있는지 확인(Find)할 때 사용한다.
크루스칼 알고리즘에서 사이클을 판별할 때 Union-Find를 사용한다.
초기에는 각 정점이 서로 다른 집합에 속해 있다.
이 상태에서 각 간선에 대해 다음의 연산을 수행한다.
MST에 이미 포함된 정점의 집합과 포함되지 않은 정점의 집합으로 나누어 정점을 관리한다.
정점을 하나씩 선택하며 MST 집합에 포함시킴으로써, MST 집합을 단계적으로 확장해 나가는 방법이다.
크루스칼 알고리즘을 사용해 구현하였다.
import java.io.*;
import java.util.*;
class Main {
static int[] parent;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int V = Integer.parseInt(st.nextToken());
int E = Integer.parseInt(st.nextToken());
// 간선을 가중치 기준 오름차순 정렬
PriorityQueue<Edge> edges = new PriorityQueue<>((e1, e2) -> e1.weight - e2.weight);
parent = new int[V + 1];
// 모든 정점의 루트를 자기 자신으로 초기화
for (int i = 1; i <= V; i++) {
parent[i] = i;
}
for (int i = 0; i < E; i++) {
st = new StringTokenizer(br.readLine());
int A = Integer.parseInt(st.nextToken());
int B = Integer.parseInt(st.nextToken());
int C = Integer.parseInt(st.nextToken());
edges.add(new Edge(A, B, C));
}
int minWeight = 0;
int edgeCount = 0;
while (edgeCount < V - 1) {
Edge curr = edges.poll();
// find
int rootA = find(curr.nodeA);
int rootB = find(curr.nodeB);
if (rootA != rootB) {
parent[rootB] = rootA; // union
minWeight += curr.weight;
edgeCount++;
}
}
System.out.println(minWeight);
}
static class Edge {
int nodeA;
int nodeB;
int weight;
Edge(int nodeA, int nodeB, int weight) {
this.nodeA = nodeA;
this.nodeB = nodeB;
this.weight = weight;
}
}
private static int find(int x) {
if (parent[x] != x) {
parent[x] = find(parent[x]);
}
return parent[x];
}
}