시간 제한 | 메모리 제한 | 제출 | 정답 | 맞힌 사람 | 정답 비율 |
---|---|---|---|---|---|
2 초 | 128 MB | 44009 | 17559 | 13305 | 41.000% |
트리(tree)는 사이클이 없는 무방향 그래프이다. 트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다. 트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다. 이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.
이런 두 노드 사이의 경로의 길이를 트리의 지름이라고 한다. 정확히 정의하자면 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.
입력으로 루트가 있는 트리를 가중치가 있는 간선들로 줄 때, 트리의 지름을 구해서 출력하는 프로그램을 작성하시오. 아래와 같은 트리가 주어진다면 트리의 지름은 45가 된다.
트리의 노드는 1부터 n까지 번호가 매겨져 있다.
파일의 첫 번째 줄은 노드의 개수 n(1 ≤ n ≤ 10,000)이다. 둘째 줄부터 n-1개의 줄에 각 간선에 대한 정보가 들어온다. 간선에 대한 정보는 세 개의 정수로 이루어져 있다. 첫 번째 정수는 간선이 연결하는 두 노드 중 부모 노드의 번호를 나타내고, 두 번째 정수는 자식 노드를, 세 번째 정수는 간선의 가중치를 나타낸다. 간선에 대한 정보는 부모 노드의 번호가 작은 것이 먼저 입력되고, 부모 노드의 번호가 같으면 자식 노드의 번호가 작은 것이 먼저 입력된다. 루트 노드의 번호는 항상 1이라고 가정하며, 간선의 가중치는 100보다 크지 않은 양의 정수이다.
첫째 줄에 트리의 지름을 출력한다.
import sys
sys.setrecursionlimit(100_000)
# initailize variables
# n: number of nodes
n = int(sys.stdin.readline())
# tree: tree data structure which is given in this problem
tree = [[] for _ in range(n+1)]
for _ in range(n-1):
parent, child, weight = map(int, sys.stdin.readline().split())
tree[parent].append((child, weight))
# diameters: diameter of each subtree
diameters = [[0, 0] for _ in range(n+1)]
# max_length_to_leaf: max length to each subtree's leaf node
max_length_to_leaf = [0 for _ in range(n+1)]
def solve(parent, tree, diameters, max_length_to_leaf):
if len(tree[parent]) == 0:
return
for child, weight in tree[parent]:
solve(child, tree, diameters, max_length_to_leaf)
diameters[parent] = sorted(
diameters[parent] + [max_length_to_leaf[child] + weight], reverse=True)[0:2]
max_length_to_leaf[parent] = max(
max_length_to_leaf[parent], max_length_to_leaf[child] + weight)
solve(1, tree, diameters, max_length_to_leaf)
print(max(map(sum, diameters)))
이 문제를 풀기 위해 먼저 생각했던 방법은 루트 노드의 아래쪽으로 dfs를 적용하는 것이었지만, 결국은 재귀와 DP를 이용한 풀이를 하게 되었다.
이 문제를 보자마자 트리의 지름 양끝에 있는 노드가 멀리 떨어져 있을 것이라고 생각할 수도 있겠으나, 극단적인 경우에는 트리의 지름이 3개의 노드로만 이뤄져 있을 수도 있다.
트리의 지름이 트리 간선을 지나간 횟수로 정의되는 것이 아니라, 가중치의 합으로 정의되기 때문이다.
예를 들어, 아래와 같은 경우를 보면
4
1 2 1
1 3 2
3 4 100
3 5 200
트리의 지름은 노드 4번-3번-5번의 경로를 지나는 300이다.
일반적으로 트리의 지름이 2번-1번-3번-4번(또는 5번)으로 정의되는데 주의가 필요한 것이다.
이러한 특성 때문에, 모든 노드에서 그 노드를 루트로 하는 서브트리에서 (루트를 경유하는 경로의)트리의 지름이 형성되는지 확인해봐야 한다.
이를 위해, 나는 max_length_to_leaf
라는 메모이제이션 용 리스트를 만들어 이를 이용해 각 서브트리의 지름 diameters
를 구했다.
diameters
에는 실제로는 자식 노드 중 max_length_to_leaf
값이 가장 큰 두 개의 max_length_to_leaf
값이 저장된다.
# diameters: diameter of each subtree
diameters = [[0, 0] for _ in range(n+1)]
# max_length_to_leaf: max length to each subtree's leaf node
max_length_to_leaf = [0 for _ in range(n+1)]
max_length_to_leaf
가 필요한 이유는, 이를 통해 서브트리의 지름을 더욱 빠르게 구할 수 있기 때문이다.
반복된 연산을 최소화하여 빠르게 연산하기 위해 사용된다. (DP)
4
1 2 1
1 3 2
3 4 100
3 5 200
예를 들어, 위와 같은 예제에서 1번 노드의 경우에는 트리의 지름이 1 + 2 + max_length_to_leaf[2] + max_length_to_leaf[3]으로 구해질 수 있다.
일반화하면, 그 노드를 루트로 하는 서브트리의 (루트를 경유하는) 지름은, 가장 큰, 그리고 그 다음으로 큰 max_length_to_leaf 값을 갖는 서로 다른 두 자식 노드로의 edge의 weight 값의 합과 max_length_to_leaf 값의 합으로 구할 수 있다고 할 수 있다.
아래의 solve()
에서는 전체 트리의 leaf node쪽에서부터 위쪽으로 max_length_to_leaf를 구하면서 올라온다.
def solve(parent, tree, diameters, max_length_to_leaf):
if len(tree[parent]) == 0:
return
for child, weight in tree[parent]:
solve(child, tree, diameters, max_length_to_leaf)
diameters[parent] = sorted(
diameters[parent] + [max_length_to_leaf[child] + weight], reverse=True)[0:2]
max_length_to_leaf[parent] = max(
max_length_to_leaf[parent], max_length_to_leaf[child] + weight)
solve()
는 재귀 함수로, 만약 leaf node에 다다르면 return한다.
모든 트리의 노드를 순회하며 diameters
와 max_length_to_leaf
를 업데이트한다.
순회가 종료되면, 우리는 diameters
의 값 중 둘을 더했을 때 최댓값이 되는 것을 정답으로 제출하면 된다.
print(max(map(sum, diameters)))