BOJ 1761 정점들의 거리

LONGNEW·2021년 11월 9일
0

BOJ

목록 보기
275/333

https://www.acmicpc.net/problem/1761
시간 2초, 메모리 128MB

input :

  • N(2 ≤ N ≤ 40,000)
  • N - 1개의 줄 : 트리 상에 연결된 두 점과 거리
  • M(1 ≤ M ≤ 10,000)
  • M개의 줄 : 한 쌍씩 입력

output :

  • 두 노드 사이의 거리를 출력

조건 :

  • N개의 정점으로 이루어진 트리

  • M개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력


개선된 LCA를 사용해서 해결하려 하였지만 parent에 거리까지 추가해서 3차원 배열을 사용하는 것이 생각보다 많은 메모리를 차지 했다.

40000 * log(40000) * 2를 하면 모든 경우가 되지 않나 했지만 잘못 된 계산이였던것 같다.

암튼 그냥 LCA를 사용해서 각각 자신의 부모를 찾는 방식으로 문제를 해결할 수 있다.
이 때 자신들의 깊이가 다를 때, 부모가 다를 때 자기 자신의 부모로 업데이트 되는 경우
거리를 합해줘야만 문제에서 원하는 해답을 구할 수 있다.

import sys
sys.setrecursionlimit(100000)

def dfs(node, deep):
    visit[node] = 1
    depth[node] = deep

    for next_node, cost in graph[node]:
        if visit[next_node] == 1:
            continue
        parent[next_node] = [node, cost]
        dfs(next_node, deep + 1)

def lca(a, b):
    ans = 0

    # 언제나 b가 더 깊은 곳에 있도록 만듬
    if depth[a] > depth[b]:
        a, b = b, a

    while depth[a] != depth[b]:
        ans += parent[b][1]
        b = parent[b][0]

    if a == b:
        return ans

    while parent[a][0] != parent[b][0]:
        ans += parent[b][1]
        ans += parent[a][1]
        b = parent[b][0]
        a = parent[a][0]

    ans += parent[a][1] + parent[b][1]
    return ans

n = int(sys.stdin.readline())
graph = [[] for _ in range(n + 1)]
visit, depth = [0] * (n + 1), [0] * (n + 1)
parent = [[0, 0] for _ in range(n + 1)]
root = [i for i in range(n + 1)]

for _ in range(n - 1):
    a, b, cost = map(int, sys.stdin.readline().split())
    graph[a].append((b, cost))
    graph[b].append((a, cost))

    # root 노드를 찾아서 모든 노드의 깊이를 구하기 위한 과정
    if root[a] > root[b]:
        root[a] = root[b]
    else:
        root[b] = root[a]

dfs(root[1], 0)

m = int(sys.stdin.readline())
for _ in range(m):
    u, v = map(int, sys.stdin.readline().split())
    print(lca(u, v))

0개의 댓글