MST 알고리즘 문제.
일단의 프림 기반으로 풀긴 했으나, 프림은 아니다. 머릿속에 프림이 잘 안떠올라서, 뭔가 우선순위 큐 있으면서 완탐했던 거 같은데.. 하면서 만들어진 코드이다.
pq
를 활용한 BFS()
로, BFS()
시작전 발전소가 있는 마을을 미리 넣어둔다. 넣을 때 만큼 카운트도 해주자.
가장 dist
가 짧은 순으로 탐색한다.
다음 노드로 넘어가기 전, 다음 노드에서 다다음 노드로 넘어가는 모든 노드를 우선 순위큐에 반영한다. 모든 노드가 연결되는 경우 ( 만큼의 poll()
이 이루어졌을때) 누적한 dist
를 반환한다.
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;
public class Main {
static class Node {
int curr;
int next;
int dist;
public Node(int curr, int next, int dist) {
this.curr = curr;
this.next = next;
this.dist = dist;
}
}
static List<Node> adjList[];
static boolean[] visited;
static PriorityQueue<Node> pq;
static int N, M, K;
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
// 노드 개수
N = Integer.parseInt(st.nextToken());
// 간선 개수
M = Integer.parseInt(st.nextToken());
// 발전소 개수
K = Integer.parseInt(st.nextToken());
int[] plant = new int[K];
st = new StringTokenizer(br.readLine());
for (int i = 0; i < K; i++) {
plant[i] = Integer.parseInt(st.nextToken());
}
adjList = new ArrayList[N + 1];
for (int i = 1; i <= N; i++) {
adjList[i] = new ArrayList<>();
}
for (int i = 0; i < M; i++) {
st = new StringTokenizer(br.readLine());
int prev = Integer.parseInt(st.nextToken());
int next = Integer.parseInt(st.nextToken());
int dist = Integer.parseInt(st.nextToken());
adjList[prev].add(new Node(prev, next, dist));
adjList[next].add(new Node(next, prev, dist));
}
visited = new boolean[N + 1];
pq = new PriorityQueue<>((n1, n2) -> n1.dist - n2.dist);
for (int idx : plant) {
visited[idx] = true;
for (Node nNode : adjList[idx]) {
pq.add(nNode);
}
}
int d = K;
int dist = 0;
while(!pq.isEmpty()) {
if(d>=N) {
System.out.println(dist);
break;
}
Node curr = pq.poll();
if(visited[curr.next]) continue;
for(Node nNode : adjList[curr.next]) {
if(visited[nNode.curr]) continue;
visited[nNode.curr] = true;
pq.add(nNode);
for(Node nNode2 : adjList[nNode.curr]) pq.add(nNode2);
dist+=curr.dist;
d++;
}
}
}
}