19581: 두 번째 트리의 지름

ewillwin·2023년 7월 26일
0

Problem Solving (BOJ)

목록 보기
151/230

풀이 시간

  • 44m
  • 처음에 dfs로 풀었는데 Recursion Error랑 시간 초과가 나서 갈아엎고 bfs로 풀었다

구현 방식

  • 그냥 트리의 지름을 구하기 위해선 루트노드에서 가장 먼 노드(farest_node_1)을 찾고, farest_node_1에서 다시 가장 먼 노드(farest_node_2)를 찾으면 된다
    (-> visit 리스트에 거리를 저장함)

  • 이 문제는 두 번째로 긴 트리의 지름을 구하는 문제이기 때문에, 1) farest_node_2를 제거한 후에 farest_node_1에서 가장 먼 노드를 찾은 후의 최장 거리의 값(candidate1)을 구하고, 2) farest_node_1을 제거한 후에 farest_node_2에서 가장 먼 노드를 찾은 후의 최장 거리의 값(candidate2)를 구해야한다. candidate1과 candidate2 중 더 큰 값이 두 번째 트리의 지름이다

시간 초과 코드 (dfs로 구현)

import sys
sys.setrecursionlimit(10**8)

def dfs(curr, weight):
    for element in graph[curr]:
        nnode, nweight = element
        if visit[nnode] == -1:
            visit[nnode] = weight + nweight
            dfs(nnode, visit[nnode])

def solution_dfs(curr, weight, ban_node):
    for element in graph[curr]:
        nnode, nweight = element
        if nnode != ban_node:
            if visit[nnode] == -1:
                visit[nnode] = weight + nweight
                solution_dfs(nnode, visit[nnode], ban_node)


N = int(sys.stdin.readline()[:-1])
graph = dict()
for n in range(N-1):
    curr, next, weight = map(int, sys.stdin.readline()[:-1].split())
    if curr in graph:
        graph[curr].append((next, weight))
    elif curr not in graph:
        graph[curr] = [(next, weight)]
    if next in graph:
        graph[next].append((curr, weight))
    elif next not in graph:
        graph[next] = [(curr, weight)]


visit = [-1] * (N+1)
visit[1] = 0
dfs(1, 0)

farest_node_1 = visit.index(max(visit))

visit = [-1] * (N+1)
visit[farest_node_1] = 0
dfs(farest_node_1, 0)

farest_node_2 = visit.index(max(visit))

# farest_node_2를 제거한 후에 farest_node_1에서 가장 먼 노드 찾기
visit = [-1] * (N+1)
visit[farest_node_1] = 0
solution_dfs(farest_node_1, 0, farest_node_2)
candidate1 = max(visit)

# farest_node_1을 제거한 후에 farest_node_2에서 가장 먼 노드 찾기
visit = [-1] * (N+1)
visit[farest_node_2] = 0
solution_dfs(farest_node_2, 0, farest_node_1)
candidate2 = max(visit)

print(max(candidate1, candidate2))

최종 코드 (bfs로 구현)

import sys
from collections import deque

def bfs(curr, weight):
    queue = deque([])
    queue.append((curr, weight))
    visit = [-1] * (N+1)
    visit[curr] = 0

    while queue:
        curr, weight = queue.popleft()
        for nnode, nweight in graph[curr]:
            if visit[nnode] == -1:
                visit[nnode] = weight + nweight
                queue.append((nnode, visit[nnode]))

    return visit.index(max(visit))

def solution_bfs(curr, weight, ban_node):
    queue = deque([])
    queue.append((curr, weight))
    visit = [-1] * (N+1)
    visit[curr] = 0

    while queue:
        curr, weight = queue.popleft()
        for nnode, nweight in graph[curr]:
            if nnode != ban_node:
                if visit[nnode] == -1:
                    visit[nnode] = weight + nweight
                    queue.append((nnode, visit[nnode]))

    return max(visit)


N = int(sys.stdin.readline()[:-1])
graph = dict()
for n in range(N-1):
    curr, next, weight = map(int, sys.stdin.readline()[:-1].split())
    if curr in graph:
        graph[curr].append((next, weight))
    elif curr not in graph:
        graph[curr] = [(next, weight)]
    if next in graph:
        graph[next].append((curr, weight))
    elif next not in graph:
        graph[next] = [(curr, weight)]

farest_node_1 = bfs(1, 0)
farest_node_2 = bfs(farest_node_1, 0)

# farest_node_2를 제거한 후에 farest_node_1에서 가장 먼 노드 찾기
candidate1 = solution_bfs(farest_node_1, 0, farest_node_2)

# farest_node_1을 제거한 후에 farest_node_2에서 가장 먼 노드 찾기
candidate2 = solution_bfs(farest_node_2, 0, farest_node_1)

print(max(candidate1, candidate2))

결과

  • 처음에 Recursion Error가 발생해서 Recursion limit을 설정해주고 다시 돌렸는데도 시간 초과가 남
  • bfs로 다시 코드 짜서 돌리니까 한큐에 성공했다
profile
💼 Software Engineer @ LG Electronics | 🎓 SungKyunKwan Univ. CSE

0개의 댓글