풀이 시간
- 44m
- 처음에 dfs로 풀었는데 Recursion Error랑 시간 초과가 나서 갈아엎고 bfs로 풀었다
구현 방식
- 그냥 트리의 지름을 구하기 위해선 루트노드에서 가장 먼 노드(farest_node_1)을 찾고, farest_node_1에서 다시 가장 먼 노드(farest_node_2)를 찾으면 된다
(-> visit 리스트에 거리를 저장함)
- 이 문제는 두 번째로 긴 트리의 지름을 구하는 문제이기 때문에, 1) farest_node_2를 제거한 후에 farest_node_1에서 가장 먼 노드를 찾은 후의 최장 거리의 값(candidate1)을 구하고, 2) farest_node_1을 제거한 후에 farest_node_2에서 가장 먼 노드를 찾은 후의 최장 거리의 값(candidate2)를 구해야한다. candidate1과 candidate2 중 더 큰 값이 두 번째 트리의 지름이다
시간 초과 코드 (dfs로 구현)
import sys
sys.setrecursionlimit(10**8)
def dfs(curr, weight):
for element in graph[curr]:
nnode, nweight = element
if visit[nnode] == -1:
visit[nnode] = weight + nweight
dfs(nnode, visit[nnode])
def solution_dfs(curr, weight, ban_node):
for element in graph[curr]:
nnode, nweight = element
if nnode != ban_node:
if visit[nnode] == -1:
visit[nnode] = weight + nweight
solution_dfs(nnode, visit[nnode], ban_node)
N = int(sys.stdin.readline()[:-1])
graph = dict()
for n in range(N-1):
curr, next, weight = map(int, sys.stdin.readline()[:-1].split())
if curr in graph:
graph[curr].append((next, weight))
elif curr not in graph:
graph[curr] = [(next, weight)]
if next in graph:
graph[next].append((curr, weight))
elif next not in graph:
graph[next] = [(curr, weight)]
visit = [-1] * (N+1)
visit[1] = 0
dfs(1, 0)
farest_node_1 = visit.index(max(visit))
visit = [-1] * (N+1)
visit[farest_node_1] = 0
dfs(farest_node_1, 0)
farest_node_2 = visit.index(max(visit))
visit = [-1] * (N+1)
visit[farest_node_1] = 0
solution_dfs(farest_node_1, 0, farest_node_2)
candidate1 = max(visit)
visit = [-1] * (N+1)
visit[farest_node_2] = 0
solution_dfs(farest_node_2, 0, farest_node_1)
candidate2 = max(visit)
print(max(candidate1, candidate2))
최종 코드 (bfs로 구현)
import sys
from collections import deque
def bfs(curr, weight):
queue = deque([])
queue.append((curr, weight))
visit = [-1] * (N+1)
visit[curr] = 0
while queue:
curr, weight = queue.popleft()
for nnode, nweight in graph[curr]:
if visit[nnode] == -1:
visit[nnode] = weight + nweight
queue.append((nnode, visit[nnode]))
return visit.index(max(visit))
def solution_bfs(curr, weight, ban_node):
queue = deque([])
queue.append((curr, weight))
visit = [-1] * (N+1)
visit[curr] = 0
while queue:
curr, weight = queue.popleft()
for nnode, nweight in graph[curr]:
if nnode != ban_node:
if visit[nnode] == -1:
visit[nnode] = weight + nweight
queue.append((nnode, visit[nnode]))
return max(visit)
N = int(sys.stdin.readline()[:-1])
graph = dict()
for n in range(N-1):
curr, next, weight = map(int, sys.stdin.readline()[:-1].split())
if curr in graph:
graph[curr].append((next, weight))
elif curr not in graph:
graph[curr] = [(next, weight)]
if next in graph:
graph[next].append((curr, weight))
elif next not in graph:
graph[next] = [(curr, weight)]
farest_node_1 = bfs(1, 0)
farest_node_2 = bfs(farest_node_1, 0)
candidate1 = solution_bfs(farest_node_1, 0, farest_node_2)
candidate2 = solution_bfs(farest_node_2, 0, farest_node_1)
print(max(candidate1, candidate2))
결과
- 처음에 Recursion Error가 발생해서 Recursion limit을 설정해주고 다시 돌렸는데도 시간 초과가 남
- bfs로 다시 코드 짜서 돌리니까 한큐에 성공했다