[백준/파이썬] 1967번: 트리의 지름

수박강아지·2025년 6월 17일

BAEKJOON

목록 보기
95/174

문제

https://www.acmicpc.net/problem/1967

풀이

  • 무방향 그래프
  • 트리의 지름 출력
    • 트리의 지름이란? 가장 먼 두 노드 사이의 거리

트리의 지름을 구하기 위해서 루트 노드(1번 노드)에서 가장 먼 노드를 구합니다.
여기서 구한 노드에서부터 가장 먼 거리를 구하면 이 값이 트리의 지름이 됩니다.

이 전략을 이용해서 문제를 풀어보겠습니다.
BFS를 이용해서 풀어봤습니다.

if __name__ == "__main__":
    n = int(input())
    graph = [[] for _ in range(n+1)] # 노드 번호가 1번부터 n번까지 존재하므로
    for _ in range(n-1): # 간선 정보 입력
        p,c,w = map(int,input().split()) # 부모, 자식, 가중치
        graph[p].append((c,w))
        graph[c].append((p,w)) # 무방향 그래프이므로 자식 노드 정보도 입력
  • 그래프의 정보 입력
def bfs(start):
    global n
    queue = deque([start])
    visited = [-1] * (n+1)
    visited[start] = 0
  • 방문한 거리를 이용해서 루트 노드로부터 총 거리 계산
  • 시작 노드는 거리 0
    while queue:
        cur = queue.popleft() # 현재 방문할 노드
        
        for node,weight in graph[cur]: # 연결된 노드와 가중치
            if visited[node] == -1: # 방문하지 않은 경우
                visited[node] = visited[cur] + weight # 거리 누적합
                queue.append(node) # 이웃 노드 큐에 추가
    
    return visited.index(max(visited)), max(visited)
  • 방문할 노드 추출 후 방문하지 않은 노드일 경우 거리 추가
  • 끝나면 가장 먼 노드 번호, 가장 멀리 떨어진 거리 return
    idx,_ = bfs(1)
    _,dis = bfs(idx)
    
    print(dis) # 거리 출력

왜 2번이나 BFS를 수행하는가?

위에 작성한 전략을 확인하면, '루트에서 가장 멀리 있는 노드를 찾은 후, 가장 먼 노드에서 가장 먼 거리를 구한다.'라고 언급했습니다.

그러기 위해서 루트 노드(1번)에서부터 가장 멀리 있는 노드를 탐색 후,
가장 멀리 있는 노드의 값에서 가장 멀리 있는 값을 찾으면 트리의 지름을 구할 수 있게 되기 때문입니다.

코드

from collections import deque
import sys
input = sys.stdin.readline

def bfs(start):
    global n
    queue = deque([start])
    visited = [-1] * (n+1)
    visited[start] = 0
    
    while queue:
        cur = queue.popleft()
        
        for node,weight in graph[cur]:
            if visited[node] == -1:
                visited[node] = visited[cur] + weight
                queue.append(node)
    
    return visited.index(max(visited)), max(visited)

if __name__ == "__main__":
    n = int(input())
    graph = [[] for _ in range(n+1)]
    for _ in range(n-1):
        p,c,w = map(int,input().split())
        graph[p].append((c,w))
        graph[c].append((p,w))
        
    idx,_ = bfs(1)
    _,dis = bfs(idx)
    
    print(dis)

0개의 댓글