우선 기본 지식을 알아 보겠다.
크루스칼 알고리즘은 상호 배타적 집합 자료구조(union-find)의 좋은 예시이다.
가중치가 가장 적은 간선일수록 MST에 포함된다는 발상으로 만들어 졌다.
하나의 시작 정점을 기준으로 가장 작은 간선과 연결된 정점을 선택해 MST가 될 때까지 모든 노드를 연결시킨다.
크루스칼 알고리즘과 다르게 모든 과정에서 항상 연결된 트리 형태를 이룬다.
크루스칼 알고리즘 | 프림 알고리즘 | |
---|---|---|
탐색 방법 | 간선 위주 | 정점 위주 |
탐색 과정 | 시작점 따로 지정없이 최소 비용의 간선을 차례대로 대입하면서 사이클이 이루어지기 전까지 탐색 | 시작점을 지정한 후 가까운 정점을 선택하면서 모든 정점을 탐색 |
사용 | 간선의 개수가 적은 경우 크루스칼 알고리즘이 용이 | 간선의 개수가 많은 경우에는 정점 위주 탐색인 프림이 용이 |
시간복잡도 | O(ElogV) | O(ElogV) |
이번에는 최소 스패닝 트리(MST)를 구하는 문제를 풀어보았다. 크루스칼, 프림 알고리즘을 통해서 구현해 보았다. 많이 들어본 알고리즘 이였지만 막상 어떤 것인지 정리하거나, 구현해보지 않았었다. 이번 문제를 통해 확실히 정리하고 가야겠다!
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
class Node implements Comparable<Node>{
int from;
int to;
int weight;
public Node(int from, int to, int weight) {
this.from = from;
this.to = to;
this.weight = weight;
}
@Override
public int compareTo(Node o) {
return this.weight-o.weight;
}
}
public class Main {
static int[] parents;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine()," ");
int v = Integer.parseInt(st.nextToken());
int e = Integer.parseInt(st.nextToken());
parents = new int[v+1];
//자기 자신을 가리키도록 초기화
for(int i=1;i<=v;i++){
parents[i] = i;
}
ArrayList<Node> nodeList = new ArrayList<>();
for(int i=0;i<e;i++){
st = new StringTokenizer(br.readLine()," ");
int from = Integer.parseInt(st.nextToken());
int to = Integer.parseInt(st.nextToken());
int weight = Integer.parseInt(st.nextToken());
nodeList.add(new Node(from,to,weight));
}
Collections.sort(nodeList);
int sum = 0;
int count = 0;
for(Node n : nodeList){
if(union(n.from, n.to)){
sum += n.weight;
count++;
if(count==e-1) break;
}
}
System.out.println(sum);
}
private static boolean union(int from, int to) {
int fromRoot = find(from);
int toRoot = find(to);
//두 노드가 같은 트리에 속해 있다면
if(fromRoot==toRoot)
return false;
//fromRoot를 toRoot의 부모로 설정해 두 트리를 합침
else {
parents[toRoot] = fromRoot;
return true;
}
}
private static int find(int v) {
if(parents[v]==v) return v; //해당 노드가 루트노드인지
else return parents[v] = find(parents[v]);
}
}
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
class Node implements Comparable<Node>{
int vertex;
int weight;
public Node(int vertex, int weight) {
this.vertex = vertex;
this.weight = weight;
}
@Override
public int compareTo(Node o) {
return this.weight-o.weight;
}
}
public class Main {
static ArrayList<Node>[] list;
static boolean[] visited;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine()," ");
int v = Integer.parseInt(st.nextToken());
int e = Integer.parseInt(st.nextToken());
visited = new boolean[v+1];
list = new ArrayList[v+1];
for(int i=0;i<=v;i++){
list[i] = new ArrayList<>();
}
for(int i=0;i<e;i++){
st = new StringTokenizer(br.readLine()," ");
int from = Integer.parseInt(st.nextToken());
int to = Integer.parseInt(st.nextToken());
int weight = Integer.parseInt(st.nextToken());
list[from].add(new Node(to,weight));
list[to].add(new Node(from,weight));
}
int result = prim(1);
System.out.println(result);
}
private static int prim(int start) {
PriorityQueue<Node> pq = new PriorityQueue<>();
int total = 0;
pq.add(new Node(start,0));
while(!pq.isEmpty()){
Node p = pq.remove();
if(visited[p.vertex]) continue;
visited[p.vertex] = true;
total += p.weight;
for(Node next : list[p.vertex]){
if(!visited[next.vertex]){
pq.add(next);
}
}
}
return total;
}
}