https://www.acmicpc.net/problem/1922
기본적인 MST 문제
프림 알고리즘으로 해결했다.
pq(우선순위 큐)를 써서 방문했는지 확인하면서 간선을 다 넣어주면 된다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.PriorityQueue;
class Pair implements Comparable<Pair>{
int dest;
int cost;
public Pair(int dest, int cost) {
this.dest = dest;
this.cost = cost;
}
public int compareTo(Pair o) {
return cost - o.cost;
}
}
public class Main {
    /**
     * Reads a weighted undirected graph from standard input and prints the total weight
     * of its minimum spanning tree.
     *
     * <p>Input: the vertex count {@code n} on one line, the edge count {@code m} on the
     * next, then {@code m} lines {@code a b c} describing an edge between vertices
     * {@code a} and {@code b} (1-based) with cost {@code c}.
     *
     * @param args unused
     * @throws IOException if reading standard input fails or it ends early
     */
    public static void main(String[] args) throws IOException {
        Main main = new Main();
        try (BufferedReader br = main.getReader()) {
            int n = main.getInt(br)[0];
            int m = main.getInt(br)[0];
            // Generic arrays cannot be created directly; this raw array only ever holds
            // ArrayList<Pair> instances, so the unchecked conversion is safe.
            @SuppressWarnings("unchecked")
            ArrayList<Pair>[] arr = new ArrayList[n + 1];
            for (int i = 1; i <= n; i++) {
                arr[i] = new ArrayList<>();
            }
            for (int i = 0; i < m; i++) {
                int[] temp = main.getInt(br);
                // Undirected edge: record it in both endpoints' adjacency lists.
                arr[temp[0]].add(new Pair(temp[1], temp[2]));
                arr[temp[1]].add(new Pair(temp[0], temp[2]));
            }
            int ans = main.solve(n, m, arr);
            System.out.println(ans);
        }
    }

    /**
     * Computes the minimum spanning tree weight with Prim's algorithm, growing the tree
     * from vertex 1.
     *
     * @param n   number of vertices, numbered 1..n
     * @param m   number of edges (not needed by the algorithm)
     * @param arr adjacency lists indexed 1..n
     * @return the total cost of the minimum spanning tree (0 when {@code n <= 0})
     * @throws IllegalArgumentException if some vertex is unreachable from vertex 1
     */
    private int solve(int n, int m, ArrayList<Pair>[] arr) {
        if (n <= 0) {
            return 0;
        }
        PriorityQueue<Pair> pq = new PriorityQueue<>();
        boolean[] visited = new boolean[n + 1];
        int visitedCount = 0;
        int cur = 1;
        int ans = 0;
        while (true) {
            visited[cur] = true;
            visitedCount++;
            if (visitedCount == n) {
                break;
            }
            pq.addAll(arr[cur]);
            // Lazy deletion: discard queued edges whose endpoint already joined the tree.
            Pair next = pq.poll();
            while (next != null && visited[next.dest]) {
                next = pq.poll();
            }
            if (next == null) {
                // Without this check the outer loop would spin forever on a disconnected graph.
                throw new IllegalArgumentException("graph is disconnected");
            }
            cur = next.dest;
            ans += next.cost;
        }
        return ans;
    }

    /** Returns a buffered reader over standard input. */
    private BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(System.in));
    }

    /**
     * Reads one line and parses its whitespace-separated integers.
     *
     * @throws IOException if reading fails or the input has already ended
     */
    private int[] getInt(BufferedReader br) throws IOException {
        String line = br.readLine();
        if (line == null) {
            throw new IOException("unexpected end of input");
        }
        // trim + \s+ tolerates leading, trailing and repeated whitespace between numbers.
        return Arrays.stream(line.trim().split("\\s+")).mapToInt(Integer::parseInt).toArray();
    }
}
import heapq
import sys


def _mst_cost(n, graph):
    """Return the total cost of a minimum spanning tree via Prim's algorithm.

    The tree is grown from vertex 1. ``graph[v]`` is a list of
    ``(neighbour, cost)`` tuples for the 1-based vertices ``1..n``.

    Returns 0 when ``n <= 0``.

    Raises:
        ValueError: if some vertex is unreachable from vertex 1 (previously the
            loop spun forever in that case).
    """
    if n <= 0:
        return 0
    visited = [False] * (n + 1)
    visited[1] = True
    remaining = n - 1
    # Heap entries are (cost, vertex) so the cheapest crossing edge pops first.
    heap = [(cost, v) for v, cost in graph[1]]
    heapq.heapify(heap)
    total = 0
    while remaining:
        if not heap:
            raise ValueError("graph is disconnected")
        cost, v = heapq.heappop(heap)
        if visited[v]:
            # Stale entry: v already joined the tree through a cheaper edge.
            continue
        visited[v] = True
        remaining -= 1
        total += cost
        for nb, nb_cost in graph[v]:
            if not visited[nb]:
                heapq.heappush(heap, (nb_cost, nb))
    return total


r = sys.stdin.readline
n = int(r())
m = int(r())
graph = [[] for _ in range(n + 1)]
for _ in range(m):
    a, b, c = map(int, r().split())
    # Undirected edge: store it from both endpoints.
    graph[a].append((b, c))
    graph[b].append((a, c))
print(_mst_cost(n, graph))