https://www.acmicpc.net/problem/1167
처음에 문제를 보고, 가장 긴 거리? 최단 경로 문제를 반대로 적용한건가? 라고 생각해서 플루워셜 알고리즘을 썼다. (지금 생각해보면 아무 논리도 없었음..) 그러다가 답이 맞지 않게 나온다는 것을 깨닫고 바로 다 지웠다.
DFS를 사용했다. (BFS 사용해도 된다)
n = int(input())
graph = [[] for _ in range(n+1)]
visited = [False]*(n+1)
distance = [0]*(n+1)
for i in range(n):
array = list(map(int, input().split()))[:-1]
a = array[0]
for j in range(1, len(array), 2):
graph[a].append((array[j], array[j+1]))
def dfs(a, b):
visited[a] = True
for i in graph[a]:
node, dist = i
if visited[node] == False:
distance[node] = max(distance[node], b+dist)
dfs(node, max(distance[node], b+dist))
dfs(1, 0)
print(max(distance))
하지만 dfs(2,0)으로 바꿀 경우 다른 값이 나오는 문제가 발생했다. 따라서 for문을 돌면서 가장 큰 값을 도출하게 바꾸어주었다.
import sys
input = sys.stdin.readline
n = int(input())
graph = [[] for _ in range(n+1)]
visited = [False]*(n+1)
distance = [0]*(n+1)
for i in range(n):
array = list(map(int, input().split()))[:-1]
a = array[0]
for j in range(1, len(array), 2):
graph[a].append((array[j], array[j+1]))
def dfs(a, b):
visited[a] = True
for i in graph[a]:
node, dist = i
if visited[node] == False:
distance[node] = max(distance[node], b+dist)
dfs(node, max(distance[node], b+dist))
result = 0
for i in range(1, n+1):
dfs(i, 0)
result = max(result, max(distance))
visited = [False]*(n+1)
distance = [0]*(n+1)
print(result)
하지만 당연하게도 시간초과...
효율적으로 코드를 완성하기 위해서는 수학적 개념을 하나 알고 있어야 한다.
노드 a에서 가장 먼 노드를 b라고 하면, b는 트리의 지름을 이루는 노드 중 하나이다.
증명은 https://blog.myungwoo.kr/112 참고
따라서 DFS를 실행시켜 가장 거리가 먼 노드를 찾은 후, 그 찾은 노드로 한번 더 DFS를 실행시켜야 한다.
import sys
input = sys.stdin.readline
n = int(input())
graph = [[] for _ in range(n+1)]
visited = [False]*(n+1)
distance = [0]*(n+1)
for i in range(n):
array = list(map(int, input().split()))[:-1]
a = array[0]
for j in range(1, len(array), 2):
graph[a].append((array[j], array[j+1]))
def dfs(a, b):
visited[a] = True
for i in graph[a]:
node, dist = i
if visited[node] == False:
distance[node] = max(distance[node], b+dist)
dfs(node, max(distance[node], b+dist))
result = 0
idx = 1
dfs(1,0)
for i in range(1, n+1):
if result < distance[i]:
result = distance[i]
idx = i
visited = [False]*(n+1)
distance = [0]*(n+1)
dfs(idx,0)
print(max(distance))