https://www.acmicpc.net/problem/3176
N개의 정점과 N-1개의 도로로 이루어진 도로 네트워크
이 문제를 풀 정도면 트리인것과 LCA를 이용해서 풀어야 한다는 것쯤은 쉽게 눈치챌수 있을텐데
도로중에서의 최솟값과 최댓값을 어떻게 구할것인지가 문제이다.
나도 LCA를 이용해서 풀어야한다는것은 쉽게 알아챘는데 저 최솟값과 최댓값 구하는데에서 애먹었다.
문제풀이는 기본 LCA를 NlogN으로 푸는 방식에다가
따로 최대값과 최소값을 정의 할 배열을 추가 해야 한다.
for j in range(1, size): for i in range(1, n + 1): parent[i][j] = parent[parent[i][j - 1]][j - 1] parent_max[i][j] = max(parent_max[i][j - 1], parent_max[parent[i][j - 1]][j - 1]) parent_min[i][j] = min(parent_min[i][j - 1], parent_min[parent[i][j - 1]][j - 1])```
위 코드에서 2^j의 간격으로 부모를 찾아내게 되고 최소값과 최대값도 비슷한 방식으로 찾으면 된다.
parent_max와 parent_min은 node i 에서 2^j위에 있는 부모까지의 edge중 최대값과 최소값이다.
매 j 루프마다 노드 i의 j-1까지의 거리를 구할수가 있다.
j == 1 인경우 2^1 =2
노드 i에서 2^0의 부모까지의 거리(1) , 부모 노드에서 다시 2^0 까지의 거리(1)를 가지고 있다.
단순 최대 최소 값만 가지고 있으면 되므로 이 두 거리를 비교하면 전체 거리를 비교한값과 같다!
j == 2인경우
노드 i에서 i의 2^1 부모까지의 거리 , i의2^1번째 부모에서 2^1 까지의 거리는 i에서 2^j까지의 거리와 같다.
from collections import defaultdict
from math import log2
import sys
input = sys.stdin.readline
sys.setrecursionlimit(110000)
n = int(input())
size = int(log2(n) + 1)
graph = [[] for _ in range(n + 1)]
for i in range(n - 1):
a, b, c = map(int, input().split())
graph[a].append([b, c])
graph[b].append([a, c])
visit = [False] * (n + 1)
depth = [0] * (n + 1)
parent = [[0] * size for _ in range(n + 1)]
parent_max = [[-float('inf')] * size for _ in range(n + 1)]
parent_min = [[float('inf')] * size for _ in range(n + 1)]
graph_max = [defaultdict(int) for _ in range(n + 1)]
def dfs(node):
visit[node] = True
for next, cost in graph[node]:
if not visit[next]:
depth[next] = depth[node] + 1
dfs(next)
parent[next][0] = node
parent_max[next][0] = cost
parent_min[next][0] = cost
dfs(1)
for j in range(1, size):
for i in range(1, n + 1):
parent[i][j] = parent[parent[i][j - 1]][j - 1]
parent_max[i][j] = max(parent_max[i][j - 1], parent_max[parent[i][j - 1]][j - 1])
parent_min[i][j] = min(parent_min[i][j - 1], parent_min[parent[i][j - 1]][j - 1])
def query(a, b):
depth_a = depth[a]
depth_b = depth[b]
answer_max = -float('inf')
answer_min = float('inf')
if depth_a > depth_b:
depth_diff = depth_a - depth_b
while depth_diff:
bit = int(log2(depth_diff))
answer_max = max(parent_max[a][bit], answer_max)
answer_min = min(parent_min[a][bit], answer_min)
a = parent[a][bit]
depth_diff -= 1 << bit
elif depth_a < depth_b:
depth_diff = depth_b - depth_a
while (depth_diff):
bit = int(log2(depth_diff))
answer_max = max(parent_max[b][bit], answer_max)
answer_min = min(parent_min[b][bit], answer_min)
b = parent[b][bit]
depth_diff -= 1 << bit
if a == b:
return answer_min, answer_max
for i in range(size - 1, -1, -1):
if parent[a][i] and parent[b][i]:
if parent[a][i] != parent[b][i]:
answer_max = max(answer_max, parent_max[a][i], parent_max[b][i])
answer_min = min(answer_min, parent_min[a][i], parent_min[b][i])
a = parent[a][i]
b = parent[b][i]
answer_max = max(answer_max, parent_max[a][0], parent_max[b][0])
answer_min = min(answer_min, parent_min[a][0], parent_min[b][0])
return answer_min, answer_max
k = int(input())
for i in range(k):
a, b = map(int, input().split())
print(*query(a, b))