[#1967] 트리의 지름

RookieAND·2022년 6월 15일
0

BaekJoon

목록 보기
13/42
post-thumbnail

❓ Question

https://www.acmicpc.net/problem/1967

📖 Before Start

트리 자료구조를 학습할 겸, 더 어려운 문제 를 풀어보기로 하였다.

이번 문제는 무방향 그래프이자, 사이클이 없는 트리의 특성을 이용한 지름 찾기 알고리즘이었다.
처음에는 트리에 지름이 있다고? 라는 의문이 들었지만 문제 설명을 자세히 보니 이해가 갔다.

✒️ Design Algorithm

루트 노드 에서 가장 거리가 먼 단말 노드 두 개를 찾고, 두 지점 간의 거리를 구하자.

처음엔 루트 노드에서 가장 거리가 먼 단말 노드를 두 개 찾고, 두 노드 간 거리를 구하면 된다 생각했다.
그러기 위해서는 먼저 두 개의 단말 노드가 루트 노드로부터 얼마나 멀리 떨어졌는지를 알아내야 했다.
따라서 처음 값을 입력 받았을 때, 간선의 가중치를 합산하여 루트 노드와 떨어진 거리로 재설정하였다.


루트 노드에서 가장 거리가 먼 노드들을 두 개 선별해보자.

  1. 루트 노드로부터 각각의 노드가 얼만큼 떨어졌는지를 구하고, 이를 하나의 List에 저장한다.
  2. 그 후 List 를 정렬하여 가장 거리가 먼 값을 구하고, 각각의 인덱스를 구해 노드를 구한다.
  3. 루트 노드로부터 두 노드 간의 거리를 합산하고, 서로 겹치는 거리만큼을 구하여 감산시킨다.

💻 Making Own Code

❌ Wrong Code

import sys

read = sys.stdin.readline
N = int(read())
tree = [list() for _ in range(N + 1)]
distance = [0] * (N + 1)

# distance에는 루트 노드로부터 해당 노드까지의 가중치를 저장.
# 각 간선의 가중치와 부모 - 자식 노드를 연결시켜 tree에 추가함.
for _ in range(N-1):
    p_node, c_node, dist = map(int, read().split())
    tree[p_node].append(c_node)
    distance[c_node] = distance[p_node] + dist

# 루트 노드로부터 가장 긴 단말 노드 2개를 선택한다.
f_len, s_len = sorted(distance, reverse=True)[:2]
f_node, s_node = (distance.index(f_len), distance.index(s_len))

# 루트 노드와 두 단말 노드 간의 겹치는 거리를 계산하여 제외해야 한다.
# 따라서 두 노드의 공통된 부모 노드를 찾는 함수를 선언한다.
def find_parent(f_nd, s_nd):
    while f_nd != s_nd:
        for i in range(N):
            if f_nd in tree[i]:
                f_nd = i
            if s_nd in tree[i]:
                s_nd = i
    return f_nd

p_node = find_parent(f_node, s_node)
result = f_len + s_len - (distance[p_node] * 2)
print(result)

풀이는 그럴싸 했으나, 결론적으로 완전히 잘못된 접근 이었다.

루트 노드로부터 가장 멀리 떨어진 두 노드를 선별한 것은 좋았다. 여기까지는 괜찮았다.
하지만 가장 멀리 떨어진 두 노드가 트리의 지름 을 이루는 두 접점이라는 보장이 없었다.

따라서 이 문제는 올바른 접근법을 알아내기 위해 결국 외부 자료의 힘을 빌릴 수밖에 없었다.
트리의 지름 을 구하기 위해서는 아래의 방법대로 알고리즘을 설계해야 한다.

  1. 특정 노드로부터 가장 멀리 떨어진 노드 A 를 구한다.
  2. 노드 A 로부터 가장 멀리 떨어진 노드 B 를 구한다.
  3. 노드 A 와 노드 B 간의 거리가 바로 트리의 지름 이 된다.

증명 공식은 https://bedamino.tistory.com/15 를 참고하였다. 아주 큰 도움이 되었다 (...)

✅ Correct Code

import sys
from collections import deque

read = sys.stdin.readline
N = int(read())
tree = [list() for _ in range(N + 1)]

# 각 간선의 가중치와 부모 - 자식 노드를 연결시켜 tree에 추가함.
for _ in range(N-1):
    p_node, c_node, dist = map(int, read().split())
    tree[p_node].append((c_node, dist))
    tree[c_node].append((p_node, dist))


# 루트 노드로부터 가장 거리가 먼 노드와, 거리를 구하는 함수를 작성.
# 해당 노드와 이어진 노드들의 목록을 덱에 추가하여 순회
def bfs(node):
    farest_node, distance = 0, 0
    queue = deque(tree[node])
    visited = [False] * (N + 1)
    visited[node] = True
    while queue:
        nd, dist = queue.popleft()
        visited[nd] = True
        # 만약 더 거리가 먼 노드를 찾았다면, 이를 업데이트 해야 함.
        if dist > distance:
            farest_node = nd
            distance = dist
        # 해당 노드까지의 거리에서 추가 가중치만큼을 더하여 덱에 추가.
        for new_nd, new_dist in tree[nd]:
            if not visited[new_nd]:
                queue.append((new_nd, dist + new_dist))
    return farest_node, distance

# 루트 노드로부터 가장 먼 노드만을 선별하여 구한다.
far_node, _ = bfs(1)
# 그 후, 가장 먼 노드로부터 가장 먼 노드와의 거리를 구한다.
_, radius = bfs(far_node)
print(radius)

bfs 함수를 설계할 때, 특정 노드를 입력 받으면 해당 노드로부터 가장 거리가 먼 노드를 반환시키게 했다.
추가로, 가장 거리가 먼 노드와의 거리 같이 저장하여 tuple 형식으로 리턴 값을 주도록 함수를 작성했다.

그 후로는 루트 노드 에서 가장 멀리 떨어진 노드인 far_node 만을 선별하여 구하고,
이를 토대로 다시 한번 BFS 탐색을 진행하여 트리의 지름인 Radius 값을 찾아내였다.

해당 알고리즘을 토대로 코드를 재작성하니, 마침내 정답 처리 를 해주셨다.

원리를 제대로 파악하지 못하면 풀기가 어려운 문제였다. 이제 확실히 알았으니 다음부터는 틀리지 말자.

📖 Conclusion

https://github.com/RookieAND/BaekJoonCode/tree/main/%EB%B0%B1%EC%A4%80/Gold/15686.%E2%80%85%EC%B9%98%ED%82%A8%E2%80%85%EB%B0%B0%EB%8B%AC

요 근래 몸 상태가 부쩍 좋지 않아 공부에 통 집중을 못하고 있지만, 그래도 꼬박꼬박 풀어야 한다!

profile
항상 왜 이걸 써야하는지가 궁금한 사람

0개의 댓글