트리 자료구조를 학습할 겸, 더 어려운 문제 를 풀어보기로 하였다.
이번 문제는 무방향 그래프이자, 사이클이 없는 트리의 특성을 이용한 지름 찾기 알고리즘이었다.
처음에는 트리에 지름이 있다고? 라는 의문이 들었지만 문제 설명을 자세히 보니 이해가 갔다.
루트 노드 에서 가장 거리가 먼 단말 노드 두 개를 찾고, 두 지점 간의 거리를 구하자.
처음엔 루트 노드에서 가장 거리가 먼 단말 노드를 두 개 찾고, 두 노드 간 거리를 구하면 된다 생각했다.
그러기 위해서는 먼저 두 개의 단말 노드가 루트 노드로부터 얼마나 멀리 떨어졌는지를 알아내야 했다.
따라서 처음 값을 입력 받았을 때, 간선의 가중치를 합산하여 루트 노드와 떨어진 거리로 재설정하였다.
루트 노드에서 가장 거리가 먼 노드들을 두 개 선별해보자.
List
에 저장한다.List
를 정렬하여 가장 거리가 먼 값을 구하고, 각각의 인덱스를 구해 노드를 구한다.import sys
read = sys.stdin.readline
N = int(read())
tree = [list() for _ in range(N + 1)]
distance = [0] * (N + 1)
# distance에는 루트 노드로부터 해당 노드까지의 가중치를 저장.
# 각 간선의 가중치와 부모 - 자식 노드를 연결시켜 tree에 추가함.
for _ in range(N-1):
p_node, c_node, dist = map(int, read().split())
tree[p_node].append(c_node)
distance[c_node] = distance[p_node] + dist
# 루트 노드로부터 가장 긴 단말 노드 2개를 선택한다.
f_len, s_len = sorted(distance, reverse=True)[:2]
f_node, s_node = (distance.index(f_len), distance.index(s_len))
# 루트 노드와 두 단말 노드 간의 겹치는 거리를 계산하여 제외해야 한다.
# 따라서 두 노드의 공통된 부모 노드를 찾는 함수를 선언한다.
def find_parent(f_nd, s_nd):
while f_nd != s_nd:
for i in range(N):
if f_nd in tree[i]:
f_nd = i
if s_nd in tree[i]:
s_nd = i
return f_nd
p_node = find_parent(f_node, s_node)
result = f_len + s_len - (distance[p_node] * 2)
print(result)
풀이는 그럴싸 했으나, 결론적으로 완전히 잘못된 접근 이었다.
루트 노드로부터 가장 멀리 떨어진 두 노드를 선별한 것은 좋았다. 여기까지는 괜찮았다.
하지만 가장 멀리 떨어진 두 노드가 트리의 지름 을 이루는 두 접점이라는 보장이 없었다.
따라서 이 문제는 올바른 접근법을 알아내기 위해 결국 외부 자료의 힘을 빌릴 수밖에 없었다.
트리의 지름 을 구하기 위해서는 아래의 방법대로 알고리즘을 설계해야 한다.
A
를 구한다.A
로부터 가장 멀리 떨어진 노드 B
를 구한다.A
와 노드 B
간의 거리가 바로 트리의 지름 이 된다.증명 공식은 https://bedamino.tistory.com/15 를 참고하였다. 아주 큰 도움이 되었다 (...)
import sys
from collections import deque
read = sys.stdin.readline
N = int(read())
tree = [list() for _ in range(N + 1)]
# 각 간선의 가중치와 부모 - 자식 노드를 연결시켜 tree에 추가함.
for _ in range(N-1):
p_node, c_node, dist = map(int, read().split())
tree[p_node].append((c_node, dist))
tree[c_node].append((p_node, dist))
# 루트 노드로부터 가장 거리가 먼 노드와, 거리를 구하는 함수를 작성.
# 해당 노드와 이어진 노드들의 목록을 덱에 추가하여 순회
def bfs(node):
farest_node, distance = 0, 0
queue = deque(tree[node])
visited = [False] * (N + 1)
visited[node] = True
while queue:
nd, dist = queue.popleft()
visited[nd] = True
# 만약 더 거리가 먼 노드를 찾았다면, 이를 업데이트 해야 함.
if dist > distance:
farest_node = nd
distance = dist
# 해당 노드까지의 거리에서 추가 가중치만큼을 더하여 덱에 추가.
for new_nd, new_dist in tree[nd]:
if not visited[new_nd]:
queue.append((new_nd, dist + new_dist))
return farest_node, distance
# 루트 노드로부터 가장 먼 노드만을 선별하여 구한다.
far_node, _ = bfs(1)
# 그 후, 가장 먼 노드로부터 가장 먼 노드와의 거리를 구한다.
_, radius = bfs(far_node)
print(radius)
bfs 함수를 설계할 때, 특정 노드를 입력 받으면 해당 노드로부터 가장 거리가 먼 노드를 반환시키게 했다.
추가로, 가장 거리가 먼 노드와의 거리 같이 저장하여 tuple
형식으로 리턴 값을 주도록 함수를 작성했다.
그 후로는 루트 노드 에서 가장 멀리 떨어진 노드인 far_node
만을 선별하여 구하고,
이를 토대로 다시 한번 BFS
탐색을 진행하여 트리의 지름인 Radius
값을 찾아내였다.
해당 알고리즘을 토대로 코드를 재작성하니, 마침내 정답 처리
를 해주셨다.
원리를 제대로 파악하지 못하면 풀기가 어려운 문제였다. 이제 확실히 알았으니 다음부터는 틀리지 말자.
요 근래 몸 상태가 부쩍 좋지 않아 공부에 통 집중을 못하고 있지만, 그래도 꼬박꼬박 풀어야 한다!