https://www.acmicpc.net/problem/1761
시간 2초, 메모리 128MB
input :
output :
조건 :
N개의 정점으로 이루어진 트리
M개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력
개선된 LCA를 사용해서 해결하려 하였지만 parent에 거리까지 추가해서 3차원 배열을 사용하는 것이 생각보다 많은 메모리를 차지 했다.
40000 * log(40000) * 2를 하면 모든 경우가 되지 않나 했지만 잘못 된 계산이였던것 같다.
암튼 그냥 LCA를 사용해서 각각 자신의 부모를 찾는 방식으로 문제를 해결할 수 있다.
이 때 자신들의 깊이가 다를 때, 부모가 다를 때 자기 자신의 부모로 업데이트 되는 경우
거리를 합해줘야만 문제에서 원하는 해답을 구할 수 있다.
import sys
sys.setrecursionlimit(100000)
def dfs(node, deep):
visit[node] = 1
depth[node] = deep
for next_node, cost in graph[node]:
if visit[next_node] == 1:
continue
parent[next_node] = [node, cost]
dfs(next_node, deep + 1)
def lca(a, b):
ans = 0
# 언제나 b가 더 깊은 곳에 있도록 만듬
if depth[a] > depth[b]:
a, b = b, a
while depth[a] != depth[b]:
ans += parent[b][1]
b = parent[b][0]
if a == b:
return ans
while parent[a][0] != parent[b][0]:
ans += parent[b][1]
ans += parent[a][1]
b = parent[b][0]
a = parent[a][0]
ans += parent[a][1] + parent[b][1]
return ans
n = int(sys.stdin.readline())
graph = [[] for _ in range(n + 1)]
visit, depth = [0] * (n + 1), [0] * (n + 1)
parent = [[0, 0] for _ in range(n + 1)]
root = [i for i in range(n + 1)]
for _ in range(n - 1):
a, b, cost = map(int, sys.stdin.readline().split())
graph[a].append((b, cost))
graph[b].append((a, cost))
# root 노드를 찾아서 모든 노드의 깊이를 구하기 위한 과정
if root[a] > root[b]:
root[a] = root[b]
else:
root[b] = root[a]
dfs(root[1], 0)
m = int(sys.stdin.readline())
for _ in range(m):
u, v = map(int, sys.stdin.readline().split())
print(lca(u, v))