[BOJ] 1967. 트리의 지름 (🥇, 트리/DFS)

lemythe423·2023년 8월 23일
0

BOJ 문제풀이

목록 보기
42/133
post-thumbnail

🔗

풀이

🫠 DFS 사용한 애매한 풀이

일단 이진 트리 형태는 아니다. 차수가 3을 넘어갈 수 있다.

기본적으로 특정 루트 노드에서 자식 노드들로 dfs 탐색을 하게 되면 각 자식노드 ~ 리프노드까지의 지름의 길이를 구할 수 있다. 그렇게 구해진 여러 자식 노드들의 반환값들 중에 최대값 2개를 구해서 더하면 그 루트 노드에서 구할 수 있는 최대 지름이 된다. 이걸 모든 노드에 대해 반복하면 전체 트리에서 가장 긴 지름의 길이를 구할 수 있다.

# 5704ms

import sys
from collections import defaultdict
sys.setrecursionlimit(40000)
def dfs(parent):
    ans = 0
    for child, weight in graph[parent]:
        ans = max(ans, dfs(child)+weight)

    return ans

# 그래프 초기화
graph = defaultdict(list)

n = int(input())
for _ in range(n-1):
    p, c, w = map(int, input().split())
    graph[p].append((c, w))

ans = 0
for i in range(1, n+1):
    res_lst = sum(sorted([dfs(c) + w for c, w in graph[i]], reverse=True)[:2])
    ans = max(ans, res_lst)

print(ans)

🤩 DFS 2번만 사용하는 풀이

대다수의 빠른 시간을 내는 풀이는 이 풀이였다

임의의 정점 x에서 가장 먼 거리의 노드 y를 찾고, 그 y에서 다시 먼 거리의 노드 z를 찾는다. y-z를 연결하는 지름이 트리의 지름이 된다.

블로그를 참고했다. 수학적 귀납법으로 이 명제를 증명했다. 읽어보면 이해가 안 되는 건 아닌데 너무 수학적이라서 조금 감(?)으로 이해를 해봤다.

일단 이 문제는 가중치를 구하는 문제이다. 이 문제에서 가장 먼 노드 = 한 노드에서 다른 노드에 도달하며 얻는 가중치의 합이 가장 큰 노드이다. 특정 노드에 도달할 때 얻을 수 있는 가중치는 어떤 노드에서 출발하느냐에 상관없이 정해져 있다. 사실 가중치가 아니라 지나쳐야 하는 자식 노드의 개수(깊이), 또는 지나쳐야 하는 일반 노드의 개수였다면 어떤 노드에서 출발하는지가 중요한 문제가 됐겠지만, 가중치의 문제로 바뀌게 되면 어떤 노드에서 출발하는지는 별로 중요하지 않다.

이 문제에서 볼 때 노드9 까지 가기 위한 거리나, 12까지 가기 위한 거리는 반드시 15와 10 가중치를 얻게 된다. 5를 거쳐 9로 가는 길, 6을 거쳐 12로 가는 길도 큰 가중치를 얻을 수 있다. 이건 3에서 출발하든, 4에서 출발하든 어디에서 출발하느냐에 상관없이 이미 정해져 있는 것이다.

그렇기 때문에 어떤 노드에서 출발하든 가장 먼 노드는 이미 가중치 합의 큰 값을 가질 수 있는 경로를 지닌 노드가 될 수 밖에 없다. 그리고 그 노드에 도달한 후에 또 같은 방식으로 가장 거리가 먼 노드를 찾게 되면, 결국 서로 가장 먼 거리에 있는 노드를 찾을 수 있게 되는 것이다. 가장 먼 거리에서 가장 먼 거리를 찾기 때문에...

즉 dfs 2번으로 해결할 수 있는 문제다.

우선 루트 노드1에서(아무 노드여도 상관없음) 가장 먼 노드를 찾고,
그 노드에서 다시 가장 먼 노드를 찾았다. dfs 함수는 모든 노드들에 대한 거리를 찾도록 구현했다. 하지만 기본적으로 파이썬의 재귀 횟수 제한을 넘어서기 때문에 오래걸리는 것 같다...

import sys
sys.setrecursionlimit(10**7)

def dfs(now, cost, visited):
    visited[now] = cost
    
    for past, w in graph[now]:
        if visited[past] == -1:
            visited = dfs(past, cost+w, visited)

    return visited

from collections import defaultdict

# 그래프 초기화
graph = defaultdict(list)

n = int(input())
for _ in range(n-1):
    p, c, w = map(int, input().split())
    graph[p].append((c, w))
    graph[c].append((p, w))

visited1 = [-1]*(n+1)
dist1 = dfs(1, 0, visited1)
a = dist1.index(max(dist1))

visited2 = [-1]*(n+1)
dist2 = dfs(a, 0, visited2)
print(max(dist2))

😎 스택을 활용한 풀이

재귀보다 훨씬 빠르다

# 88ms

def dfs(i):
    visited = [-1] * (n+1)
    visited[i] = 0
    stack = [i]

    while stack:
        node = stack.pop()

        for past, w in graph[node]:
            if visited[past] == -1:
                visited[past] = visited[node] + w
                stack.append(past)

    return visited

def solution():
    dist = dfs(1)
    return max(dfs(dist.index(max(dist))))

from collections import defaultdict

graph = defaultdict(list)

n = int(input())
for _ in range(n-1):
    p, c, w = map(int, input().split())
    graph[p].append((c, w))
    graph[c].append((p, w))

print(solution())
profile
아무말이나하기

0개의 댓글