N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다.
아래의 두 쿼리를 수행하는 프로그램을 작성하시오.
1 u v: u에서 v로 가는 경로의 비용을 출력한다.
2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.
첫째 줄에 N (2 ≤ N ≤ 100,000)이 주어진다.
둘째 줄부터 N-1개의 줄에는 i번 간선이 연결하는 두 정점 번호 u와 v와 간선의 비용 w가 주어진다.
다음 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.
다음 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다.
간선의 비용은 항상 1,000,000보다 작거나 같은 자연수이다.
각각의 쿼리의 결과를 순서대로 한 줄에 하나씩 출력한다.
두개의 동작을 하는 코드를 짜야한다.
1. u -> v 의 경로 비용 구하기
2. u -> v 의 경로 중 k 번째 노드 찾기
1을 위해 dfs를 진행하며 루트 노드로부터의 비용을 저장한 dist[] 배열을 사용한다.
두 노드간 거리는 dist[u] + dist[v] - 2 * dist[lca]가 된다.
이건 직접 그려보면 더 이해가 빠르다.
2에서 lca를 기준으로 k번째가 왼쪽에 있는지 오른쪽에 있는지 판단하고 lca 알고리즘을 수행했던 것처럼 top-down 방식으로 탐색해준다.
lca 자체가 어려운 알고리즘이라 계속해서 복기하는 게 필요할 것 같다.
import java.io.*;
import java.util.*;
public class Main {
static List[] tree;
static int[][] parent; // parent[i][j] : j의 2^i 번째 부모
static int[] depth;
static long[] dist;
static int N, S, M;
static class Edge {
int to;
long dist;
public Edge(int to, long dist) {
this.to = to;
this.dist = dist;
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringBuilder sb = new StringBuilder();
N = Integer.parseInt(br.readLine());
S = 0;
for (int i = 1; i <= N; i *= 2) {
S++;
}
parent = new int[S][N + 1];
tree = new List[N + 1];
depth = new int[N + 1];
dist = new long[N + 1];
for (int i = 1; i <= N; i++) {
tree[i] = new ArrayList<Edge>();
}
StringTokenizer st;
for (int i = 0; i < N - 1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
long w = Long.parseLong(st.nextToken());
tree[u].add(new Edge(v, w));
tree[v].add(new Edge(u, w));
}
dfs(1, 1);
for (int i = 1; i < S; i++) {
for (int j = 1; j <= N; j++) { // 2^i 번째 부모는 2^i-1번째 부모의 2^i-1번째 부모
parent[i][j] = parent[i - 1][parent[i - 1][j]];
}
}
M = Integer.parseInt(br.readLine());
for (int i = 0; i < M; i++) {
st = new StringTokenizer(br.readLine());
int type = Integer.parseInt(st.nextToken());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
int lca = lca(a, b);
if(type == 1) {
sb.append(dist[a] + dist[b] - 2 * dist[lca]).append("\n");
} else {
int k = Integer.parseInt(st.nextToken());
sb.append(getKth(a, b, lca, k)).append("\n");
}
}
bw.write(sb.toString());
bw.flush();
bw.close();
}
static int getKth(int a, int b, int lca, int k) {
if(k == depth[a] - depth[lca] + 1) return lca;
if(k < depth[a] - depth[lca] + 1) { // a ~ lca에서 탐색
int depthK = depth[a] - k + 1;
for (int i = S - 1; i >= 0; i--) {
if (depthK <= depth[parent[i][a]]) {
a = parent[i][a];
}
}
return a;
} else { // lca ~ b 에서 탐색
int depthK = depth[lca] + (k - (depth[a] - depth[lca])) - 1;
for (int i = S-1; i >= 0 ; i--) {
if (depthK <= depth[parent[i][b]]) {
b = parent[i][b];
}
}
return b;
}
}
static void dfs(int node, int count) {
depth[node] = count;
int len = tree[node].size();
for (int i = 0; i < len; i++) {
Edge edge = (Edge) tree[node].get(i);
int next = edge.to;
if(depth[next] == 0) {
dist[next] = dist[node] + edge.dist;
dfs(next, count + 1);
parent[0][next] = node;
}
}
}
static int lca(int a, int b) {
if(depth[a] > depth[b]) { // 항상 b가 더 깊도록
int temp = a;
a = b;
b = temp;
}
for (int i = S-1; i >= 0; i--) {
if(depth[a] <= depth[parent[i][b]]) {
b = parent[i][b];
}
}
if(a == b) return a;
for (int i = S-1; i >= 0; i--) {
if(parent[i][a] != parent[i][b]) {
a = parent[i][a];
b = parent[i][b];
}
}
return parent[0][a];
}
}