[백준] 1967번 : 트리의 지름 (JAVA)

인간몽쉘김통통·2024년 1월 10일

백준

목록 보기
46/92

문제



이해

트리의 지름은 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.

가중치가 있는 트리를 입력받고 트리의 지름을 구하여라.

접근

처음에는 트리의 지름을 다르게 생각했다.

트리의 모양에서 트리의 지름을 가지는 노드는 공통 조상을 기준으로 생각할 수 있기 때문에 한 노드에서 트리의 지름이 될 수 있는 거리를 측정하고 이를 비교하기로 하였다.

이 방법은 서브트리의 지름을 측정하고 이를 DP로 처리하여 최종값을 구하는 방법이었다.

하지만 요구하는 정보도 많았고 무엇보다도 시간문제가 발생할 것 같아 다른 방법을 생각하기로 하였다.


예제의 그림에서 힌트를 얻을 수 있었는데 트리의 지름은 양쪽에서 당길 때 모든 노드가 경로 안에 포함된다.

그렇다면 우리가 노드 탐색의 기준으로 삼을 루트 역시도 원 안에 포함되기 때문에 루트로부터 가장 먼 노드를 탐색한다.

루트로부터 먼 노드를 구했다면 해당 노드는 지름의 기준이 되는 노드 중 하나이다.

따라서, 이 노드를 기준으로 다시 가장 먼 노드를 구한다면? 이 때의 가중치의 합이 트리의 지름이 되는 것이다.

정리해보면 다음과 같다.
1. 루트를 시작으로 가장 먼 노드 탐색 (DFS)
2. 이전에 구했던 노드에서부터 다시 가장 먼 노드 탐색 (DFS)
3. 최대 거리 출력

코드

package java_baekjoon;

import java.util.*;
import java.io.*;

public class prob1967 {
    static class vertex {
        int num;
        int distance;

        public vertex(int num, int distance) {
            this.num = num;
            this.distance = distance;
        }
    }

    static int N;
    static ArrayList<vertex>[] v_list;
    static boolean[] visited;
    static int max = 0;
    static int max_idx = 0;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        N = Integer.parseInt(br.readLine());

        v_list = new ArrayList[N + 1];
        for (int i = 0; i <= N; i++) {
            v_list[i] = new ArrayList<>();
        }

        for (int i = 0; i < N - 1; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int parent = Integer.parseInt(st.nextToken());
            int child = Integer.parseInt(st.nextToken());
            int distance = Integer.parseInt(st.nextToken());
            v_list[parent].add(new vertex(child, distance));
            v_list[child].add(new vertex(parent, distance));
        }
        visited = new boolean[N + 1];
        visited[1] = true;
        dfs(1, 0);

        visited = new boolean[N + 1];
        visited[max_idx] = true;
        dfs(max_idx, 0);
        System.out.println(max);
    }

    static void dfs(int idx, int sum) {
        if (max < sum) {
            max = sum;
            max_idx = idx;
        }

        for (vertex v : v_list[idx]) {
            if (!visited[v.num]) {
                visited[v.num] = true;
                dfs(v.num, sum + v.distance);
            }
        }
    }
}

결과

profile
SW 0년차 개발자입니다.

0개의 댓글