[백준]#1197 최소 스패닝 트리

SeungBird·2020년 8월 26일
0

⌨알고리즘(백준)

목록 보기
42/401

문제

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

입력

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

최소 스패닝 트리의 가중치가 -2147483648보다 크거나 같고, 2147483647보다 작거나 같은 데이터만 입력으로 주어진다.

출력

첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.

예제 입력 1

3 3
1 2 1
2 3 2
1 3 3

예제 출력 1

3

풀이

이 문제는 '[SWEA]#1251 하나로' 문제와 마찬가지 MST문제로 프림 알고리즘과 우선순위 큐를 사용해서 풀었다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class Main {

    static boolean[] visited;
    static ArrayList<Node>[] nodeList;

    public static void main(String args[]) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] str = br.readLine().split(" ");
        int N = Integer.parseInt(str[0]);
        int E = Integer.parseInt(str[1]);
        nodeList = new ArrayList[N+1];
        visited = new boolean[N+1];

        for(int i=1; i<N+1; i++) {
            nodeList[i] = new ArrayList<>();
        }

        for(int i=0; i<E; i++) {
            String[] str1 = br.readLine().split(" ");
            nodeList[Integer.parseInt(str1[0])].add(new Node(Integer.parseInt(str1[0]), Integer.parseInt(str1[1]), Double.parseDouble(str1[2])));
            nodeList[Integer.parseInt(str1[1])].add(new Node(Integer.parseInt(str1[1]), Integer.parseInt(str1[0]), Double.parseDouble(str1[2])));
        }

        double answer = MST();
        System.out.println(Math.round(answer));
    }

    public static double MST() {
        PriorityQueue<Node> pq = new PriorityQueue<>();
        ArrayList<Node> tempList;
        Node tempNode;
        Queue<Integer> queue = new LinkedList<>();
        double answer = 0;

        queue.add(1);
        while(!queue.isEmpty()) {
            int currentNode = queue.poll();
            visited[currentNode] = true;
            tempList = nodeList[currentNode];

            for(int i=0; i<tempList.size(); i++) {
                if(!visited[tempList.get(i).end]) {
                    pq.add(tempList.get(i));
                }
            }

            while(!pq.isEmpty()) {
                tempNode = pq.poll();

                if(!visited[tempNode.end]) {
                    visited[tempNode.end]=true;
                    answer += tempNode.cost;
                    queue.add(tempNode.end);
                    break;
                }
            }
        }
        return answer;
    }

    static class Node implements Comparable<Node>{
        int start;
        int end;
        double cost;

        public Node(int start, int end, double cost) {
            this.start = start;
            this.end = end;
            this.cost = cost;
        }

        @Override
        public int compareTo(Node node) {
            return this.cost > node.cost ? 1 : -1;
        }
    }
}

(+추가) 프림 알고리즘 말고 크루스칼 알고리즘으로도 다시 풀어보았다

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.BufferedReader;
import java.util.PriorityQueue;

public class Main {
    static PriorityQueue<Node> queue = new PriorityQueue<>();
    static int[] parent;

    public static void main(String[] args) throws IOException{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] input = br.readLine().split(" ");
        int V = Integer.parseInt(input[0]);
        int E = Integer.parseInt(input[1]);
        parent = new int[V+1];
        for(int i=1; i<=V; i++)
            parent[i] = i;

        for(int i=0; i<E; i++) {
            input = br.readLine().split(" ");
            queue.add(new Node(Integer.parseInt(input[0]), Integer.parseInt(input[1]), Integer.parseInt(input[2])));
        }

        System.out.println(kruskal());
    }

    public static int kruskal() {
        int res = 0;

        while(!queue.isEmpty()) {
            Node temp = queue.poll();
            int start = temp.start;
            int end = temp.end;
            int cost = temp.cost;

            int s = find(start);
            int e = find(end);

            if(s==e) continue;

            union(start, end);
            res += cost;
        }
        return res;
    }

    public static class Node implements Comparable<Node>{
        int start;
        int end;
        int cost;

        public Node(int start, int end, int cost) {
            this.start = start;
            this.end = end;
            this.cost = cost;
        }

        @Override
        public int compareTo(Node n) {
            return this.cost>n.cost ? 1 : -1;
        }
    }

    public static void union(int a, int b) {
        a = find(a);
        b = find(b);

        if(a<b) parent[b] = a;
        else parent[a] = b;
    }

    public static int find(int x) {
        if(parent[x] == x)
            return x;

        return find(parent[x]);
    }
}
profile
👶🏻Jr. Programmer

0개의 댓글