[백준] 1197번 : 최소 스패닝 트리

김건우·2023년 9월 28일
0

문제 풀이

목록 보기
29/62

최소 스패닝 트리


풀이 방법

우선 기본 지식을 알아 보겠다.

스패닝 트리 or 신장 트리

  • 원래 그래프의 정점 전부와 간선의 부분 집합으로 구성된 부분 그래프
  • 스패닝 트리에 포함된 간선들은 정점들을 트리 형태로 전부 연결해야 함.(간선들이 사이클을 이루지 않아야 함.)

최소 스패닝 트리(MST)

  • 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리

이 문제를 해결하는 방법으로는 크게 2가지 알고리즘이 있다.

1. 크루스칼 알고리즘

크루스칼 알고리즘은 상호 배타적 집합 자료구조(union-find)의 좋은 예시이다.
가중치가 가장 적은 간선일수록 MST에 포함된다는 발상으로 만들어 졌다.

  • 그래프의 모든 간선을 가중치의 오름차순으로 정렬한 뒤, 스패닝 트리에 하나 씩 추가한다.
  • 사이클을 이루는지 체크해서 사이클을 이루면 추가하지 않는다.
  • 이 과정을 모든 간선에 대해서 반복한다.

2. 프림 알고리즘

하나의 시작 정점을 기준으로 가장 작은 간선과 연결된 정점을 선택해 MST가 될 때까지 모든 노드를 연결시킨다.
크루스칼 알고리즘과 다르게 모든 과정에서 항상 연결된 트리 형태를 이룬다.

  • 임의의 간선을 선택한다.
  • 선택한 간선의 정점으로부터 가장 낮은 가중치를 갖는 정점을 선택한다.
  • 모든 정점에 대하여 반복한다.

크루스칼 vs 프림

크루스칼 알고리즘프림 알고리즘
탐색 방법간선 위주정점 위주
탐색 과정시작점 따로 지정없이 최소 비용의 간선을 차례대로 대입하면서 사이클이 이루어지기 전까지 탐색시작점을 지정한 후 가까운 정점을 선택하면서 모든 정점을 탐색
사용간선의 개수가 적은 경우 크루스칼 알고리즘이 용이간선의 개수가 많은 경우에는 정점 위주 탐색인 프림이 용이
시간복잡도O(ElogV)O(ElogV)

참고 블로그 https://loosie.tistory.com/159

느낀 점

이번에는 최소 스패닝 트리(MST)를 구하는 문제를 풀어보았다. 크루스칼, 프림 알고리즘을 통해서 구현해 보았다. 많이 들어본 알고리즘 이였지만 막상 어떤 것인지 정리하거나, 구현해보지 않았었다. 이번 문제를 통해 확실히 정리하고 가야겠다!


소스 코드

크루스칼 알고리즘

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

class Node implements Comparable<Node>{
    int from;
    int to;
    int weight;
    public Node(int from, int to, int weight) {
        this.from = from;
        this.to = to;
        this.weight = weight;
    }

    @Override
    public int compareTo(Node o) {
        return this.weight-o.weight;
    }
}
public class Main {
    static int[] parents;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine()," ");

        int v = Integer.parseInt(st.nextToken());
        int e = Integer.parseInt(st.nextToken());

        parents = new int[v+1];
        //자기 자신을 가리키도록 초기화
        for(int i=1;i<=v;i++){
            parents[i] = i;
        }

        ArrayList<Node> nodeList = new ArrayList<>();
        for(int i=0;i<e;i++){
            st = new StringTokenizer(br.readLine()," ");
            int from = Integer.parseInt(st.nextToken());
            int to = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());

            nodeList.add(new Node(from,to,weight));
        }

        Collections.sort(nodeList);

        int sum = 0;
        int count = 0;
        for(Node n : nodeList){
            if(union(n.from, n.to)){
                sum += n.weight;
                count++;

                if(count==e-1) break;
            }
        }
        System.out.println(sum);
    }

    private static boolean union(int from, int to) {
        int fromRoot = find(from);
        int toRoot = find(to);

        //두 노드가 같은 트리에 속해 있다면
        if(fromRoot==toRoot)
            return false;
        //fromRoot를 toRoot의 부모로 설정해 두 트리를 합침
        else {
            parents[toRoot] = fromRoot;
            return true;
        }
    }

    private static int find(int v) {
        if(parents[v]==v) return v; //해당 노드가 루트노드인지
        else return parents[v] = find(parents[v]);
    }
}

프림 알고리즘

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

class Node implements Comparable<Node>{
    int vertex;
    int weight;

    public Node(int vertex, int weight) {
        this.vertex = vertex;
        this.weight = weight;
    }

    @Override
    public int compareTo(Node o) {
        return this.weight-o.weight;
    }
}

public class Main {
    static ArrayList<Node>[] list;
    static boolean[] visited;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine()," ");

        int v = Integer.parseInt(st.nextToken());
        int e = Integer.parseInt(st.nextToken());

        visited = new boolean[v+1];
        list = new ArrayList[v+1];
        for(int i=0;i<=v;i++){
            list[i] = new ArrayList<>();
        }

        for(int i=0;i<e;i++){
            st = new StringTokenizer(br.readLine()," ");
            int from = Integer.parseInt(st.nextToken());
            int to = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());

            list[from].add(new Node(to,weight));
            list[to].add(new Node(from,weight));
        }

        int result = prim(1);
        System.out.println(result);
    }

    private static int prim(int start) {
        PriorityQueue<Node> pq = new PriorityQueue<>();
        int total = 0;
        pq.add(new Node(start,0));

        while(!pq.isEmpty()){
            Node p = pq.remove();

            if(visited[p.vertex]) continue;

            visited[p.vertex] = true;
            total += p.weight;

            for(Node next : list[p.vertex]){
                if(!visited[next.vertex]){
                    pq.add(next);
                }
            }
        }
        return total;
    }
}
profile
공부 정리용

0개의 댓글