[백준] 1197번 : 최소 스패닝 트리 - JAVA [자바]

doxxx·2023년 5월 17일
0

백준

목록 보기
3/7
post-thumbnail

링크

관찰

출력에 적힌대로 최소 스패닝 트리의 가중치를 구하면 되는 문제이다.

보통 MST를 구하는 방법은 크루스칼 알고리즘과 프림 알고리즘이 있다.

MST

일단 MST란 무엇인지 알아보자.

MST(Minimum Spanning Tree)는 가능한 모든 Spanning Tree중에서 최소의 가중치를 갖는 트리를 의미한다.

Spanning Tree는 연결 유한 그래프의 모든 정점을 포함하는 트리를 의미한다. 이러한 스패닝 트리의 속성에 "최소 가중치"를 가져야 한다는 조건을 추가하면 MST가 되는 것을 알 수 있다.

관찰에 언급한대로, 이 MST를 다항 시간 내에 찾을 수 있는 알고리즘으로는 크루스칼 알고리즘과 프림 알고리즘이 있다.

크루스칼 알고리즘

크루스칼 알고리즘은 주어진 그래프의 모든 Edge들의 가중치를 오름차순으로 정렬한다.

이후, 사이클을 형성하지 않는 선에서 지금까지 만든 MST에 Edge를 하나씩 하나씩 추가한다.

모든 Edge들을 탐색하며 greedy하게 선택하며, Edge가 전체 노드의 개수 - 1개가 될 때 까지 Edge를 추가하는 과정을 반복한다.

크루스칼 알고리즘을 따르게 되면, 단 1개의 MST를 가지게 된다.

보통 E를 Edge의 개수, V를 Node의 개수라고 하게 되면, 시간 복잡도는 O(ElogV)안에 동작한다고 증명된다.

E는 최대 V^2개일 수 있기 때문에, log 연산을 하게 될 경우, logE == logV^2 == 2logV임을 이용한다.

이에 필요한 자료구조는 서로소 집합(disjoint-set)이고, 이러한 서로소 집합을 서로 합치는 union과정과 어떤 Node가 어떤 서로소 집합에 속하는지 탐색하는 find과정을 이용하여 구한다.

흔하게 알려진 union-find 알고리즘이다.

프림 알고리즘

프림 알고리즘은 하나의 Node를 선택하여 트리를 만드는 것으로 시작한다.

그래프의 모든 Edge들이 들어있는 집합을 만들고, 모든 Node가 트리에 속해있지 않는 선에서, 지금까지 만들어진 트리와 연결된 Edge들 중에서 트리 속의 두 Node를 연결하지 않는 가중치가 가장 작은 Edge를 트리에 추가 하게 된다.

크루스칼 알고리즘과 같이 Edge의 개수가 전체 노드의 개수 - 1개가 될때 까지 위의 과정을 반복한다.

프림 알고리즘을 구현하는 방법에는 최소 힙(보통 우선 순위 큐를 이용하여 구현), 이진 힙, 피보나치 힙 등이 있다.

물론 각각의 방식은 서로 다른 시간 복잡도를 가지고, 그래프의 형태에 따라서 유리한 방식이 다르다.

비교

  1. 시작: 크루스칼 알고리즘은 최소 가중치를 갖는 Node에서 시작하지만, 프림 알고리즘은 모든 정점을 고려한다.
  2. 순회: 크루스칼 알고리즘은 1개의 노드에 대해서 1번만 순회하지만, 프림 알고리즘은 2번 이상 순회하게 된다.
  3. 시간 복잡도: 프림 알고리즘은 어떤 자료구조를 이용하느냐에 따라서 차이가 나지만 크루스칼 알고리즘의 경우 변함이 없다.
  4. 적합성: 프림 알고리즘은 밀도가 높은 그래프에, 크루스칼 알고리즘은 밀도가 적은 그래프에서 더 빠르게 동작한다.

비고

프림 알고리즘의 경우, 외판원 문제와 모든 도시를 연결하는 도로나 철로에 대한 문제에 적용된다.

프림 알고리즘은 그 유명한 다이크스트라와 프림이 함께 재발견하여서, Prim-Dijkstra 알고리즘으로도 알려져 있다. 프림 알고리즘과 다이크스트라의 알고리즘은 코드 상에서도 유사한 점이 많지만, 서로 전-혀 다른 목적으로 사용되므로 두개를 혼동하지 않도록 한다.

크루스칼 알고리즘의 경우, LAN이나 TV 네트워크에 적용된다.

코드

필자는 크루스칼 알고리즘을 이용해 구현했다.

우선순위 큐는 그냥 단순하게, 가중치대로 정렬하기 위해 사용했다.

그냥 ArrayList를 사용하고 정렬해도 된다.

import java.io.*;
import java.util.*;

public class Main {

    static int v;
    static int e;

    static int[] parent;
    static int[] rank;

    static PriorityQueue<Edge> edges;

    static int find(int x) {
        if (parent[x] == x) {
            return x;
        }
        return parent[x] = find(parent[x]);
    }

    static void union(int x, int y) {
        x = find(x);
        y = find(y);

        if (x == y) {
            return;
        }

        if (rank[x] < rank[y]) {
            parent[x] = y;
        } else {
            parent[y] = x;

            if (rank[x] == rank[y]) {
                rank[x]++;
            }
        }
    }


    static class Edge implements Comparable<Edge> {

        int from;
        int to;
        int weight;

        public Edge(int from, int to, int weight) {
            this.from = from;
            this.to = to;
            this.weight = weight;
        }

        @Override
        public int compareTo(Edge o) {
            return this.weight - o.weight;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        v = Integer.parseInt(st.nextToken());
        e = Integer.parseInt(st.nextToken());

        parent = new int[v + 1];
        rank = new int[v + 1];
        edges = new PriorityQueue<>();

        for (int i = 1; i <= v; i++) {
            parent[i] = i;
        }

        for (int i = 0; i < e; i++) {
            st = new StringTokenizer(br.readLine());
            int from = Integer.parseInt(st.nextToken());
            int to = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());

            edges.add(new Edge(from, to, weight));
        }

        int sum = 0;

        while (!edges.isEmpty()) {
            Edge edge = edges.poll();

            if (find(edge.from) == find(edge.to)) {
                continue;
            }

            union(edge.from, edge.to);
            sum += edge.weight;
        }

        System.out.println(sum);
    }
}

0개의 댓글