[Python][Platinum 4] BOJ 3176 도로 네트워크

태규 최·2022년 3월 13일
0

Algorithm

목록 보기
7/8

https://www.acmicpc.net/problem/3176

N개의 정점과 N-1개의 도로로 이루어진 도로 네트워크

이 문제를 풀 정도면 트리인것과 LCA를 이용해서 풀어야 한다는 것쯤은 쉽게 눈치챌수 있을텐데

도로중에서의 최솟값과 최댓값을 어떻게 구할것인지가 문제이다.

나도 LCA를 이용해서 풀어야한다는것은 쉽게 알아챘는데 저 최솟값과 최댓값 구하는데에서 애먹었다.

문제풀이는 기본 LCA를 NlogN으로 푸는 방식에다가

따로 최대값과 최소값을 정의 할 배열을 추가 해야 한다.

for j in range(1, size):
    for i in range(1, n + 1):
        parent[i][j] = parent[parent[i][j - 1]][j - 1]
        parent_max[i][j] = max(parent_max[i][j - 1], parent_max[parent[i][j - 1]][j - 1])
        parent_min[i][j] = min(parent_min[i][j - 1], parent_min[parent[i][j - 1]][j - 1])``` 

위 코드에서 2^j의 간격으로 부모를 찾아내게 되고 최소값과 최대값도 비슷한 방식으로 찾으면 된다.

parent_max와 parent_min은 node i 에서 2^j위에 있는 부모까지의 edge중 최대값과 최소값이다.

매 j 루프마다 노드 i의 j-1까지의 거리를 구할수가 있다.

j == 1 인경우 2^1 =2
노드 i에서 2^0의 부모까지의 거리(1) , 부모 노드에서 다시 2^0 까지의 거리(1)를 가지고 있다.
단순 최대 최소 값만 가지고 있으면 되므로 이 두 거리를 비교하면 전체 거리를 비교한값과 같다!

j == 2인경우

노드 i에서 i의 2^1 부모까지의 거리 , i의2^1번째 부모에서 2^1 까지의 거리는 i에서 2^j까지의 거리와 같다.

from collections import defaultdict
from math import log2
import sys

input = sys.stdin.readline
sys.setrecursionlimit(110000)
n = int(input())
size = int(log2(n) + 1)
graph = [[] for _ in range(n + 1)]
for i in range(n - 1):
    a, b, c = map(int, input().split())
    graph[a].append([b, c])
    graph[b].append([a, c])

visit = [False] * (n + 1)

depth = [0] * (n + 1)

parent = [[0] * size for _ in range(n + 1)]
parent_max = [[-float('inf')] * size for _ in range(n + 1)]
parent_min = [[float('inf')] * size for _ in range(n + 1)]

graph_max = [defaultdict(int) for _ in range(n + 1)]


def dfs(node):
    visit[node] = True

    for next, cost in graph[node]:
        if not visit[next]:
            depth[next] = depth[node] + 1
            dfs(next)
            parent[next][0] = node
            parent_max[next][0] = cost
            parent_min[next][0] = cost


dfs(1)
for j in range(1, size):
    for i in range(1, n + 1):
        parent[i][j] = parent[parent[i][j - 1]][j - 1]
        parent_max[i][j] = max(parent_max[i][j - 1], parent_max[parent[i][j - 1]][j - 1])
        parent_min[i][j] = min(parent_min[i][j - 1], parent_min[parent[i][j - 1]][j - 1])


def query(a, b):
    depth_a = depth[a]
    depth_b = depth[b]
    answer_max = -float('inf')
    answer_min = float('inf')
    if depth_a > depth_b:
        depth_diff = depth_a - depth_b
        while depth_diff:
            bit = int(log2(depth_diff))
            answer_max = max(parent_max[a][bit], answer_max)
            answer_min = min(parent_min[a][bit], answer_min)
            a = parent[a][bit]
            depth_diff -= 1 << bit
    elif depth_a < depth_b:
        depth_diff = depth_b - depth_a
        while (depth_diff):
            bit = int(log2(depth_diff))
            answer_max = max(parent_max[b][bit], answer_max)
            answer_min = min(parent_min[b][bit], answer_min)
            b = parent[b][bit]
            depth_diff -= 1 << bit
    if a == b:
        return answer_min, answer_max

    for i in range(size - 1, -1, -1):
        if parent[a][i] and parent[b][i]:
            if parent[a][i] != parent[b][i]:
                answer_max = max(answer_max, parent_max[a][i], parent_max[b][i])
                answer_min = min(answer_min, parent_min[a][i], parent_min[b][i])
                a = parent[a][i]
                b = parent[b][i]

    answer_max = max(answer_max, parent_max[a][0], parent_max[b][0])
    answer_min = min(answer_min, parent_min[a][0], parent_min[b][0])
    return answer_min, answer_max


k = int(input())
for i in range(k):
    a, b = map(int, input().split())
    print(*query(a, b))

0개의 댓글