트리에서 두 정점의 최소 공통 조상을 구하는 자료구조를 배워 봅시다.
LCA를 효율적으로 구해 봅시다.
이번 문제는 두 노드의 쌍 M(1 ≤ M ≤ 100,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력하는 문제이다.
- dfs를 통해서 자신의 부모를 구한다.
- 이중 for문을 통해서 조상을 세팅한다.
- parent[j][i-1] = j의 2^i-1만큼 위에 있는 조상
- 두 노드의 가장 가까운 조상(부모) 출력 (LCA알고리즘)
- swap (a가 더 얕으면 swap)
- 두 노드 높이를 맞춤
- 같은 높이일 경우, 두 노드의 값이 같으면 종료
- 같은 높이일 경우, 두 노드의 값이 다르면 두 노드의 높이를 올려가며 확인
Java / Python
import java.io.*;
import java.util.*;
public class Main {
private static int N, M, K;
private static int[] depth;
private static int[][] parent;
private static ArrayList<Integer>[] list;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st;
StringBuilder sb = new StringBuilder();
N = Integer.parseInt(br.readLine());
list = new ArrayList[N + 1];
for (int i = 0; i < N + 1; i++) {
list[i] = new ArrayList<>();
}
for (int i = 0; i < N - 1; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
list[a].add(b);
list[b].add(a);
}
K = 0;
int tmp = 1;
while (tmp < N + 1) {
tmp <<= 1;
K++;
}
depth = new int[N + 1];
parent = new int[N + 1][K];
dfs(1, 1);
for (int i = 1; i < K; i++) {
// 2^K 번째 조상 노드 저장
for (int j = 1; j <= N; j++) {
parent[j][i] = parent[parent[j][i - 1]][i - 1];
}
}
M = Integer.parseInt(br.readLine());
for (int i = 0; i < M; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
sb.append(LCA(a, b)).append("\n");
}
bw.write(sb.toString() + "\n");
bw.flush();
bw.close();
br.close();
}
private static void dfs(int node, int cur) {
depth[node] = cur;
for (Integer next : list[node]) {
if (depth[next] == 0) {
dfs(next, cur + 1);
parent[next][0] = node;
}
}
}
private static int LCA(int a, int b) {
if (depth[a] < depth[b]) {
// a가 더 얕으면 swap
int temp = a;
a = b;
b = temp;
}
for (int i = K - 1; i >= 0; i--) {
if (Math.pow(2, i) <= depth[a] - depth[b]) {
// 높이 차이 만큼 a 높이 올리기
a = parent[a][i];
}
}
if (a == b)
return a;
for (int i = K - 1; i >= 0; i--) {
if (parent[a][i] != parent[b][i]) {
a = parent[a][i];
b = parent[b][i];
}
}
return parent[a][0];
}
}
import sys
sys.setrecursionlimit(100000)
input = sys.stdin.readline
LOG = 21
# 2^i 단위의 부모값을 저장하기 위한 크기
N = int(input())
parent = [[0] * LOG for _ in range(N + 1)]
visit = [False] * (N + 1)
depth = [0] * (N + 1)
tree = [[] for _ in range(N + 1)]
for _ in range(N - 1):
a, b = map(int, input().split())
tree[a].append(b)
tree[b].append(a)
def dfs(cur, d):
visit[cur] = True
depth[cur] = d
for node in tree[cur]:
if not visit[node]:
parent[node][0] = cur
dfs(node, d + 1)
def lca(a, b):
if depth[a] > depth[b]:
a, b = b, a
# a와 b의 깊이 동일하게
for i in range(LOG - 1, -1, -1):
if depth[b] - depth[a] >= (1<<i):
b = parent[b][i]
if a == b:
return a
# 올라가면서 공통 조상 찾기
for i in range(LOG - 1, -1, -1):
if parent[a][i] != parent[b][i]:
a = parent[a][i]
b = parent[b][i]
return parent[a][0]
dfs(1, 0)
# 모든 노드의 전체 부모 관계 갱신
for i in range(1, LOG):
for j in range(1, N + 1):
# 각 노드에 대해 2^i번째 부모 정보 갱신
parent[j][i] = parent[parent[j][i - 1]][i - 1]
M = int(input())
for _ in range(M):
a, b = map(int, input().split())
print(lca(a, b))