트리에서 두 정점의 최소 공통 조상을 구하는 자료구조를 배워 봅시다.
트리 상의 경로에서 k번째 정점을 구하는 문제
이번 문제는 N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다. 두 쿼리를 수행하는 프로그램을 작성하는 문제이다.
1 u v: u에서 v로 가는 경로의 비용을 출력한다.
2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.
dfs와 LCA를 이용해서, 이전 LCA 문제와 비슷한 방식으로 구할 수 있다. 조건을 나누어 거리(간선 비용)를 구한다.
Java : LCA함수와 parent를 이용해 조상을 구하고, dfs를 통해 depth를 확인하고 dist를 구한다.
Python : parent를 이용해 각 노드의 부모 노드 및 depth 계산하고, dp를 이용해 희소 테이블을 초기화 및 희소 테이블을 계산해주고, 조건을 나누어 거리를 구할 수 있다.
Java / Python
import java.io.*;
import java.util.*;
public class Main {
static int N, M; // N : 정점수, M : 쿼리 수
static int[] depth;
static long[] dist;
static int[][] parent; // parent[j][i] = parent[parent[j][i - 1]][i - 1];
static ArrayList<Node>[] tree;
static class Node {
int target, cost;
public Node(int target, int cost) {
this.target = target;
this.cost = cost;
}
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st;
StringBuilder sb = new StringBuilder();
N = Integer.parseInt(br.readLine());
depth = new int[N + 1];
dist = new long[N + 1];
parent = new int[N + 1][18];
tree = new ArrayList[N + 1];
for (int i = 1; i < N + 1; i++) {
tree[i] = new ArrayList<Node>();
}
for (int i = 0; i < N - 1; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
int c = Integer.parseInt(st.nextToken());
tree[a].add(new Node(b, c));
tree[b].add(new Node(a, c));
}
DFS(1, 1);
// parent 채우기
for (int i = 1; i < 18; i++) {
for (int j = 2; j <= N; j++) {
parent[j][i] = parent[parent[j][i - 1]][i - 1];
}
}
// LCA
M = Integer.parseInt(br.readLine());
for (int i = 0; i < M; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
int root = LCA(u, v);
if (a == 1) {
sb.append(dist[u] + dist[v] - 2 * dist[root] + "\n");
} else {
int k = Integer.parseInt(st.nextToken());
int cnt = depth[u] - depth[root] + 1;
if (cnt == k)
sb.append(root + "\n");
else if (cnt > k) {
k--;
int tmp = u;
for (int j = 0; j < 18; j++) {
if ((k & 1 << j) != 0) {
k -= 1 << j;
tmp = parent[tmp][j];
}
}
sb.append(tmp + "\n");
} else {
k = cnt + depth[v] - depth[root] - k + 1;
k--;
int tmp = v;
for (int j = 0; j < 18; j++) {
if ((k & 1 << j) != 0) {
k -= 1 << j;
tmp = parent[tmp][j];
}
}
sb.append(tmp + "\n");
}
}
}
bw.write(sb.toString());
bw.flush();
bw.close();
br.close();
}
// depth 확인
// dist 추가
static void DFS(int node, int cur) {
depth[node] = cur;
for (Node next : tree[node]) {
if (depth[next.target] == 0) {
parent[next.target][0] = node;
dist[next.target] = dist[node] + next.cost;
DFS(next.target, cur + 1);
}
}
return;
}
static int LCA(int a, int b) {
if (depth[a] < depth[b]) {
// a가 더 얕으면 swap
int temp = a;
a = b;
b = temp;
}
for (int i = 18; i >= 0; i--) {
if (Math.pow(2, i) <= depth[a] - depth[b]) {
a = parent[a][i]; // 높이 차이 만큼 a 높이 올리기
}
}
if (a == b) return a;
for (int i = 17; i >= 0; i--) {
if (parent[a][i] != parent[b][i]) {
a = parent[a][i];
b = parent[b][i];
}
}
return parent[a][0];
}
}
import sys
from collections import deque
from math import log2
input = sys.stdin.readline
# tree 입력, 정렬, 부모노드, depth 계산 부분
N = int(input())
tree = [[] for _ in range(N + 1)]
for _ in range(N - 1):
a, b, c = map(int, input().split())
tree[a].append([b, c])
tree[b].append([a, c])
parent = [[0, 0] for _ in range(N + 1)]
depth = [0] * (N + 1)
visit = [False] * (N + 1)
que = deque([1])
visit[1] = True
while que:
now = que.popleft()
for b, c in tree[now]:
if not visit[b]:
que.append(b)
parent[b][0] = now
parent[b][1] = c
depth[b] = depth[now] + 1
visit[b] = True
# 희소 테이블
logN = int(log2(N) + 1)
dp = [[[0, 0] for _ in range(logN)] for __ in range(N + 1)]
for i in range(N + 1):
dp[i][0][0] = parent[i][0]
dp[i][0][1] = parent[i][1]
for j in range(1, logN):
for i in range(1, N + 1):
dp[i][j][0] = dp[dp[i][j - 1][0]][j - 1][0]
dp[i][j][1] = dp[i][j - 1][1] + dp[dp[i][j - 1][0]][j - 1][1]
# 쿼리문 입력, 처리
M = int(sys.stdin.readline())
for _ in range(M):
Query = list(map(int, input().split()))
u, v = Query[1], Query[2]
u2, v2 = u, v
# 공통 조상 탐색
if depth[u2] < depth[v2]:
u2, v2 = v2, u2
diff = depth[u2] - depth[v2]
for i in range(logN):
if diff & 1 << i:
u2 = dp[u2][i][0]
if u2 == v2:
lca = u2
else:
for i in range(logN - 1, -1, -1):
if dp[u2][i][0] != dp[v2][i][0]:
u2 = dp[u2][i][0]
v2 = dp[v2][i][0]
lca = dp[u2][0][0]
if Query[0] == 1:
cost = 0
diff_u = depth[u] - depth[lca]
diff_v = depth[v] - depth[lca]
for i in range(logN):
if diff_u & 1 << i:
cost += dp[u][i][1]
u = dp[u][i][0]
if diff_v & 1 << i:
cost += dp[v][i][1]
v = dp[v][i][0]
print(cost)
else:
k = Query[3]
# u 의 k - 1 조상을 계산
if k <= depth[u] - depth[lca]:
diff = k - 1
for i in range(logN):
if diff & 1 << i:
u = dp[u][i][0]
print(u)
else: # 남은 거리를 v부터 계산
diff = depth[v] + depth[u] - 2 * depth[lca] - k + 1
for i in range(logN - 1, -1, -1):
if diff & 1 << i:
v = dp[v][i][0]
print(v)