[Java] 백준 1197번: 최소 스패닝 트리

U·2025년 2월 22일

백준

목록 보기
89/116

[문제 바로 가기] - 최소 스패닝 트리

유형 : 최소 신장 트리(MST)

💡 접근 방식

문제에 나와있다시피, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 최소 신장(스패닝) 트리라고 한다.

MST는 1. 크루스칼 알고리즘2. 프림 알고리즘으로 풀이할 수 있는데 두 풀이 다 익숙하지 않아 둘 다 풀이해봤다.

먼저 최소 신장 트리는 그래프에서 최소 비용을 구하는 문제다. 예를 들면 문제처럼 모든 정점을 연결하는 간선들의 가중치의 합이 최소가 되는 트리거나, 두 정점 사이의 최소 경로 찾기(최단경로)일 때를 의미한다. 이때 무향 가중치 그래프여야 한다.

KRUSKAL 알고리즘

  • 간선을 하나씩 선택해서 MST를 찾는 알고리즘
  1. 모든 간선을 가중치에 따라 오름차순 정렬
  2. 가중치가 가장 낮은 간선부터 선택하면서 트리를 증가
  3. 이때 사이클이 존재하면 남아 있는 간선 중 다음으로 가중치가 낮은 간선 선택
  4. 2를 반복

PRIM 알고리즘

  • 정점을 중심으로 간선을 보며 풀어나가는 알고리즘 : 인접 행렬 / 인접 리스트
  • 하나의 정점에서 연결된 간선들 중에 하나씩 선택하면서 MST를 만들어가는 방식
  1. 임의 정점을 하나 선택해서 시작
  2. 선택한 정점과 인접하는 정점들 중의 최소 비용의 간선이 존재하는 정점 선택
  3. 모든 정점이 선택될 때까지 1, 2 반복

★ 꿀팁으로 간선이 많을 경우에는 프림, 간선이 적을 경우에는 크루스칼을 사용하는 것이 좋다. 크루스칼은 간선이 많으면 정렬에 시간이 많이 걸리기 때문이다!


풀이

크루스칼 알고리즘 사용

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

/**
 * 백준 1197번 최소 스패닝 트 리 
 * - 크루스칼/프림
 */

public class Main {
	public static class Node implements Comparable<Node> {
		int to, from, value;
		
		Node(int to, int from, int value) {
			this.to = to;
			this.from = from;
			this.value = value;
		}
		
		@Override
		public int compareTo(Node o) {
			return Integer.compare(this.value, o.value);
		}
	}
	
	static int parent[];
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		
		int V = Integer.parseInt(st.nextToken());
		int E = Integer.parseInt(st.nextToken());
		
		PriorityQueue<Node> queue = new PriorityQueue<>();
		
		parent = new int[V + 1];
		for (int i = 1; i <= V; i++) parent[i] = i;
		
		for (int i = 0; i < E; i++) {
			st = new StringTokenizer(br.readLine());
			int A = Integer.parseInt(st.nextToken());
			int B = Integer.parseInt(st.nextToken());
			int C = Integer.parseInt(st.nextToken());
			
			queue.add(new Node(A, B, C));
		}
		
		int sum = 0;
		while (!queue.isEmpty()) {
			Node cur = queue.poll();
			
			int toParent = find(cur.to);
			int fromParent = find(cur.from);
			
			if (toParent != fromParent) {
				sum += cur.value;
				union(cur.to, cur.from);
			}
		}
		
		System.out.println(sum);
	}
	
	public static int find(int node) {
		if (parent[node] == node) return node;
		return parent[node] = find(parent[node]);
	}
	
	public static void union(int to, int from) {
		int toParent = find(to);
		int fromParent = find(from);
		
		if (to != from) parent[fromParent] = toParent;
	}
}

프림 알고리즘 사용

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

/**
 * 백준 1197번 최소 스패닝 트 리 
 * - 크루스칼/프림
 */

public class Main {
	public static class Node implements Comparable<Node> {
		int node, value;
		
		Node(int node, int value) {
			this.node = node;
			this.value = value;
		}
		
		@Override
		public int compareTo(Node o) {
			return Integer.compare(this.value, o.value);
		}
	}
	
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		
		int V = Integer.parseInt(st.nextToken());
		int E = Integer.parseInt(st.nextToken());
		
		List<Node>[] list = new ArrayList[V + 1];
		for (int i = 0; i <= V; i++) {
			list[i] = new ArrayList<>();
		}
		
		for (int i = 0; i < E; i++) {
			st = new StringTokenizer(br.readLine());
			int A = Integer.parseInt(st.nextToken());
			int B = Integer.parseInt(st.nextToken());
			int C = Integer.parseInt(st.nextToken());
			
			list[A].add(new Node(B, C));
			list[B].add(new Node(A, C));
		}
		
		int sum = 0;
		boolean[] visited = new boolean[V + 1];
		PriorityQueue<Node> queue = new PriorityQueue<>();
		queue.add(new Node(1, 0));
		
		while (!queue.isEmpty()) {
			Node cur = queue.poll();
			
			int node = cur.node;
			int value = cur.value;
			
			if (visited[node]) continue;
			visited[node] = true;
			
			sum += value;
			
			for (Node next : list[node]) {
				if (!visited[next.node]) {
					queue.add(next);
				}
			}
		}
		
		System.out.println(sum);
	}
}
profile
백엔드 개발자 연습생

0개의 댓글