import sys
input=sys.stdin.readline
sys.setrecursionlimit(10**9)
N=int(input())
Tree=[ [] for _ in range(N+1) ]
def DFS(start):
for i,j in Tree[start]:
if not visit[i]:
visit[i]=True
distance[i]=distance[start]+j
DFS(i)
for i in range(N-1):
u,v,k=map(int,input().split())
Tree[u].append( (v,k) )
Tree[v].append( (u,k) )
distance=[0]*(N+1)
visit=[False]*(N+1)
visit[1]=True
DFS(1) #루트 노드의 번호는 항상 1이라고 가정한다.
check=distance.index(max(distance)) # 처음 탐색의 가장 긴 길이를 확인해준다.
distance=[0]*(N+1) # distance 초기화
visit=[False]*(N+1)
visit[check]=True
DFS(check)
print(max(distance))
📌 어떻게 접근할 것인가?
트리의 지름을 구하는 방법은 루트노드에서 가장 거리가 먼 지점 를 구한뒤 트리의 지점에서 거리가 가장 먼 지점이 트리의 지름이 됩니다.
이것을 증명하기에는 너무 복잡하기도 하고 깔끔하게 정리하신 분이 있기 때문에
bedamino님의 블로그 를 참조하거나
귀류법으로 증명하신 구사과님의 블로그 님의 블로그를 참조하는걸 추천합니다.
이번 글에서는 증명은 하지않고 코드에 대해서 자세히 살표 보겠습니다.
✅ 코드
먼저 트리를 입력받을때 거리값도 함께 입력받습니다.
그리고 총 3가지를 선언해줍니다.
distance 배열은 거리를 저장하고 visit 배열은 방문 체크를 합니다. 이때 문제에서 루트노드는 무조건 1 이라고 정했으므로 미리 visit 인덱스 1의 값은 방문 처리해줍니다.
def DFS(start):
for i,j in Tree[start]:
if not visit[i]:
visit[i]=True
distance[i]=distance[start]+j
DFS(i)
DFS(1)
이후 를 탐색 해줍니다. distance[i]=distance[start]+j 를 통해서 매번 노드 사이의 거리를 구해줍니다.
루트노드는 1이기 때문에 처음 start 값은 1로 정해줍니다.
이후 한번번의 탐색으로 루트노드에서 가장 거리가 먼 지점 을 distance.index(max(distance)) 로 통해 구할수 있습니다.
이후 이 지점을 루트노드라고 생각하고 다시 3가지 조건을 선언해줍니다.
distance 배열을 초기화해주고 visit 배열도 초기화해줍니다.
이때 중요한것은 루트노드를 1이 아니라 distance 의 최대값인덱스로 선언해줬으므로
visit[check] 를 방문처리 해줍니다.
이후 다시 를 탐색하면 트리의 지름을 구할 수 있습니다.