백준 13511 트리와쿼리2

뀨뀨찬찬·2021년 10월 13일
0

알고리즘

목록 보기
2/12

문제

N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다.
아래의 두 쿼리를 수행하는 프로그램을 작성하시오.
1 u v: u에서 v로 가는 경로의 비용을 출력한다.
2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.

입력

첫째 줄에 N (2 ≤ N ≤ 100,000)이 주어진다.
둘째 줄부터 N-1개의 줄에는 i번 간선이 연결하는 두 정점 번호 u와 v와 간선의 비용 w가 주어진다.
다음 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.
다음 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다.
간선의 비용은 항상 1,000,000보다 작거나 같은 자연수이다.

출력

각각의 쿼리의 결과를 순서대로 한 줄에 하나씩 출력한다.

풀이

두개의 동작을 하는 코드를 짜야한다.
1. u -> v 의 경로 비용 구하기
2. u -> v 의 경로 중 k 번째 노드 찾기

1을 위해 dfs를 진행하며 루트 노드로부터의 비용을 저장한 dist[] 배열을 사용한다.
두 노드간 거리는 dist[u] + dist[v] - 2 * dist[lca]가 된다.
이건 직접 그려보면 더 이해가 빠르다.

2에서 lca를 기준으로 k번째가 왼쪽에 있는지 오른쪽에 있는지 판단하고 lca 알고리즘을 수행했던 것처럼 top-down 방식으로 탐색해준다.

lca 자체가 어려운 알고리즘이라 계속해서 복기하는 게 필요할 것 같다.

import java.io.*;
import java.util.*;

public class Main {
    static List[] tree;
    static int[][] parent; // parent[i][j] : j의 2^i 번째 부모
    static int[] depth;
    static long[] dist;
    static int N, S, M;
    static class Edge {
        int to;
        long dist;

        public Edge(int to, long dist) {
            this.to = to;
            this.dist = dist;
        }
    }
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringBuilder sb = new StringBuilder();

        N = Integer.parseInt(br.readLine());
        S = 0;
        for (int i = 1; i <= N; i *= 2) {
            S++;
        }
        parent = new int[S][N + 1];
        tree = new List[N + 1];
        depth = new int[N + 1];
        dist = new long[N + 1];
        for (int i = 1; i <= N; i++) {
            tree[i] = new ArrayList<Edge>();
        }
        StringTokenizer st;
        for (int i = 0; i < N - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            long w = Long.parseLong(st.nextToken());
            tree[u].add(new Edge(v, w));
            tree[v].add(new Edge(u, w));
        }
        dfs(1, 1);
        for (int i = 1; i < S; i++) {
            for (int j = 1; j <= N; j++) { // 2^i 번째 부모는 2^i-1번째 부모의 2^i-1번째 부모
                parent[i][j] = parent[i - 1][parent[i - 1][j]];
            }
        }
        M = Integer.parseInt(br.readLine());
        for (int i = 0; i < M; i++) {
            st = new StringTokenizer(br.readLine());
            int type = Integer.parseInt(st.nextToken());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            int lca = lca(a, b);
            if(type == 1) {
                sb.append(dist[a] + dist[b] - 2 * dist[lca]).append("\n");
            } else {
                int k = Integer.parseInt(st.nextToken());
                sb.append(getKth(a, b, lca, k)).append("\n");
            }
        }
        bw.write(sb.toString());
        bw.flush();
        bw.close();
    }
    static int getKth(int a, int b, int lca, int k) {
        if(k == depth[a] - depth[lca] + 1) return lca;
        if(k < depth[a] - depth[lca] + 1) { // a ~ lca에서 탐색
            int depthK = depth[a] - k + 1;
            for (int i = S - 1; i >= 0; i--) {
                if (depthK <= depth[parent[i][a]]) {
                    a = parent[i][a];
                }
            }
            return a;
        } else { // lca ~ b 에서 탐색
            int depthK = depth[lca] + (k - (depth[a] - depth[lca])) - 1;
            for (int i = S-1; i >= 0 ; i--) {
                if (depthK <= depth[parent[i][b]]) {
                    b = parent[i][b];
                }
            }
            return b;
        }
    }
    static void dfs(int node, int count) {
        depth[node] = count;
        int len = tree[node].size();
        for (int i = 0; i < len; i++) {
            Edge edge = (Edge) tree[node].get(i);
            int next = edge.to;
            if(depth[next] == 0) {
                dist[next] = dist[node] + edge.dist;
                dfs(next, count + 1);
                parent[0][next] = node;
            }
        }
    }

    static int lca(int a, int b) {
        if(depth[a] > depth[b]) { // 항상 b가 더 깊도록
            int temp = a;
            a = b;
            b = temp;
        }

        for (int i = S-1; i >= 0; i--) {
            if(depth[a] <= depth[parent[i][b]]) {
                b = parent[i][b];
            }
        }

        if(a == b) return a;

        for (int i = S-1; i >= 0; i--) {
            if(parent[i][a] != parent[i][b]) {
                a = parent[i][a];
                b = parent[i][b];
            }
        }
        return parent[0][a];
    }
}
profile
공부하고 있어요!

0개의 댓글