트리에서 두 정점의 최소 공통 조상을 구하는 자료구조를 배워 봅시다.
트리 상의 경로에서 최솟값과 최댓값을 찾는 문제
이번 문제는 모든 도시의 쌍에는 그 도시를 연결하는 유일한 경로가 있고, 각 도로의 길이는 입력으로 주어지고 총 K개의 도시 쌍이 주어질 때, 두 도시를 연결하는 경로 상에서 가장 짧은 도로의 길이와 가장 긴 도로의 길이를 구하는 문제이다.
Java / Python
dfs와 LCA를 이용해서, 공통 조상을 구하며, 최소 거리와 최대 거리를 계산한다.
import java.io.*;
import java.util.*;
public class Main {
static int N, K, k; // N : 정점수, k : 쿼리 수, k: 2의 지수
static int[] depth;
static int[][] parent; // parent[j][i] = parent[parent[j][i - 1]][i - 1];
static ArrayList<Node>[] tree;
// 도로 네트워크 변수
// min(max)Dist[k][V] 정점 V의 2^K번째 조상까지의
static int[][] minDist; // 최소거리
static int[][] maxDist; // 최대거리
static int min, max;
static class Node {
int target, cost;
public Node(int target, int cost) {
this.target = target;
this.cost = cost;
}
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st;
StringBuilder sb = new StringBuilder();
N = Integer.parseInt(br.readLine());
k = 0;
for (int i = 1; i <= N; i *= 2) {
k++;
}
depth = new int[N + 1];
parent = new int[N + 1][k];
minDist = new int[N + 1][k];
maxDist = new int[N + 1][k];
tree = new ArrayList[N + 1];
for (int i = 1; i < N + 1; i++) {
tree[i] = new ArrayList<Node>();
}
int a, b, c;
for (int i = 1; i <= N - 1; i++) {
st = new StringTokenizer(br.readLine());
a = Integer.parseInt(st.nextToken());
b = Integer.parseInt(st.nextToken());
c = Integer.parseInt(st.nextToken());
tree[a].add(new Node(b, c));
tree[b].add(new Node(a, c));
}
DFS(1, 1);
// parent 채우기
for (int i = 1; i < k; i++) {
for (int j = 1; j <= N; j++) {
parent[j][i] = parent[parent[j][i - 1]][i - 1];
minDist[j][i] = Math.min(minDist[j][i - 1], minDist[parent[j][i - 1]][i - 1]);
maxDist[j][i] = Math.max(maxDist[j][i - 1], maxDist[parent[j][i - 1]][i - 1]);
}
}
// LCA
K = Integer.parseInt(br.readLine());
for (int i = 1; i <= K; i++) {
st = new StringTokenizer(br.readLine());
a = Integer.parseInt(st.nextToken());
b = Integer.parseInt(st.nextToken());
LCA(a, b);
sb.append(min + " " + max + "\n");
}
bw.write(sb.toString());
bw.flush();
bw.close();
br.close();
}
// depth 확인
static void DFS(int node, int cur) {
depth[node] = cur;
for (Node next : tree[node]) {
if (depth[next.target] == 0) {
DFS(next.target, cur + 1);
parent[next.target][0] = node;
// 현재 cost로 갱신
minDist[next.target][0] = next.cost;
maxDist[next.target][0] = next.cost;
}
}
return;
}
static void LCA(int a, int b) {
if (depth[a] < depth[b]) {
// a가 더 얕으면 swap
int temp = a;
a = b;
b = temp;
}
min = Integer.MAX_VALUE;
max = -1;
for (int i = k - 1; i >= 0; i--) {
if (Math.pow(2, i) <= depth[a] - depth[b]) {
min = Math.min(min, minDist[a][i]);
max = Math.max(max, maxDist[a][i]);
a = parent[a][i]; // 높이 차이 만큼 a 높이 올리기
}
}
if (a == b)
return;
for (int i = k - 1; i >= 0; i--) {
if (parent[a][i] != parent[b][i]) {
min = Math.min(min, Math.min(minDist[a][i], minDist[b][i]));
max = Math.max(max, Math.max(maxDist[a][i], maxDist[b][i]));
a = parent[a][i];
b = parent[b][i];
}
}
min = Math.min(min, Math.min(minDist[a][0], minDist[b][0]));
max = Math.max(max, Math.max(maxDist[a][0], maxDist[b][0]));
return;
}
}
dp를 이용해 (3차원 배열 이용 [0]: parent, [1]: minLength, [2]: maxLength) 초기화 및 희소 테이블을 계산해주고, 두 도시 사이의 거리를 구할 수 있다.
import sys
import math
from collections import deque
input = sys.stdin.readline
N = int(input())
K = int(math.log2(N))+1
tree = [[] for _ in range(N + 1)]
for i in range(N - 1):
a, b, w = map(int, input().split())
tree[a].append((b,w))
tree[b].append((a,w))
queue = deque([(1, 1)])
depth = [0] * (N + 1)
depth[1] = 1
dp = [[[0,0,0] for _ in range(K)] for _ in range(N+1)]
while queue:
i, d = queue.popleft()
for j, w in tree[i]:
if not depth[j]:
queue.append((j, d + 1))
depth[j] = d + 1
dp[j][0] = [i,w,w]
for j in range(1, K):
for i in range(1, N + 1):
dp[i][j][0] = dp[dp[i][j-1][0]][j-1][0]
dp[i][j][1] = min(dp[i][j-1][1], dp[dp[i][j-1][0]][j-1][1])
dp[i][j][2] = max(dp[i][j-1][2], dp[dp[i][j-1][0]][j-1][2])
for _ in range(int(input())):
a, b = map(int, input().split())
max_len = 0
min_len = float('inf')
if depth[a] > depth[b]:
a, b = b, a
for i in range(K, -1, -1):
if depth[b] - depth[a] >= (1<<i):
min_len = min(dp[b][i][1], min_len)
max_len = max(dp[b][i][2], max_len)
b = dp[b][i][0]
if a == b:
print(min_len, max_len)
continue
for i in range(K-1, -1, -1):
if dp[a][i][0] != dp[b][i][0]:
min_len = min(dp[a][i][1],dp[b][i][1], min_len)
max_len = max(dp[a][i][2],dp[b][i][2], max_len)
a = dp[a][i][0]
b = dp[b][i][0]
min_len = min(dp[a][0][1],dp[b][0][1], min_len)
max_len = max(dp[a][0][2],dp[b][0][2], max_len)
print(min_len, max_len)