트리(tree)는 사이클이 없는 무방향 그래프이다. 트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다. 트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다. 이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.
이런 두 노드 사이의 경로의 길이를 트리의 지름이라고 한다. 정확히 정의하자면 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.
입력으로 루트가 있는 트리를 가중치가 있는 간선들로 줄 때, 트리의 지름을 구해서 출력하는 프로그램을 작성하시오. 아래와 같은 트리가 주어진다면 트리의 지름은 45가 된다.
트리의 노드는 1부터 n까지 번호가 매겨져 있다.
파일의 첫 번째 줄은 노드의 개수 n(1 ≤ n ≤ 10,000)이다. 둘째 줄부터 n-1개의 줄에 각 간선에 대한 정보가 들어온다. 간선에 대한 정보는 세 개의 정수로 이루어져 있다. 첫 번째 정수는 간선이 연결하는 두 노드 중 부모 노드의 번호를 나타내고, 두 번째 정수는 자식 노드를, 세 번째 정수는 간선의 가중치를 나타낸다. 간선에 대한 정보는 부모 노드의 번호가 작은 것이 먼저 입력되고, 부모 노드의 번호가 같으면 자식 노드의 번호가 작은 것이 먼저 입력된다. 루트 노드의 번호는 항상 1이라고 가정하며, 간선의 가중치는 100보다 크지 않은 양의 정수이다.
첫째 줄에 트리의 지름을 출력한다.
12
1 2 3
1 3 2
2 4 5
3 5 11
3 6 9
4 7 1
4 8 7
5 9 15
5 10 4
6 11 6
6 12 10
45
트리에 가중치가 붙어있어서 흔히 다익스트라나 플로이드 워셜 문제로 헷갈릴 수가 있다. 하지만
"그래프일 경우"
간선의 가중치가 모두 동일 -> BFS
그렇지 않은 경우 -> 다익스트라
"트리인 경우"
DFS or BFS
유의하도록 하자.
필자는 이 사실을 모르고 보자마자 바로 플로이드 워셜로 풀었다.
답은 나오지만 메모리 초과가 발생한다.
심지어 플로이드 워셜의 시간 복잡도는 O(n^3)이다. 문제에서 입력값의 범위는 10000개가 최대라고 하였는데, 10000개는 보통 O(n^2)의 시간 복잡도를 가지는 알고리즘을 이용하여 풀어야 한다.
그렇다면 또 하나의 의문이 생길 수 있다.
다익스트라를 이용하여 풀면 되지 않을까?
다익스트라의 시간 복잡도는 O(nlogn)이기에 적합하다고 생각할 수 있다.
맞다. 다익스트라로도 풀 수 있다. 조금있다가 문제 풀이에서 보여줄 것이지만 풀 수는 있다.
하지만 이 문제는 특별 케이스이다.
만약 다익스트라로 접근을 한다면 보통 플로이드 워셜처럼 모든 노드에서 갈 수 있는 최대 경로나 최단 경로를 구해야 한다. 그렇다면 시간 복잡도는 O(n^2logn)이기 때문에 시간초과가 뜬다.
이 문제에서는 굳이 n개의 노드를 전부 다익스트라를 사용하지 않아도 되므로 통과하기는 한다. 하지만 트리에서 가중치가 되어있으면 dfs나 bfs로 접근하도록 하자.
import sys
N = int(sys.stdin.readline())
INF = 10 ** 9
graph = [[INF] * (N + 1) for _ in range(N + 1)]
for i in range(1, N + 1):
for j in range(1, N + 1):
if i == j:
graph[i][j] = 0
for i in range(N - 1):
a, b, c = map(int, sys.stdin.readline().split())
graph[a][b] = c
graph[b][a] = c
for k in range(1, N + 1):
for i in range(1, N + 1):
for j in range(1, N + 1):
graph[i][j] = min(graph[i][j], graph[i][k] + graph[k][j])
max_list = []
max_val = 0
for i in range(1, N + 1):
for j in graph[i]:
if j == INF:
continue
else:
if max_val < j:
max_val = j
max_list.append(max_val)
max_val = 0
print(max(max_list))
import sys
sys.setrecursionlimit(10**6)
N = int(sys.stdin.readline())
graph = [[] for _ in range(N + 1)]
for i in range(1, N):
a, b, c = map(int, sys.stdin.readline().split())
graph[a].append([b, c])
graph[b].append([a, c])
distance = [-1] * (N + 1)
distance[1] = 0
def dfs(start, dis):
for i in range(len(graph[start])):
next, next_dis = graph[start][i]
if distance[next] == -1:
distance[next] = next_dis + dis
dfs(next, distance[next])
dfs(1, 0)
max_index = distance.index(max(distance))
distance = [-1] * (N + 1)
distance[max_index] = 0
dfs(max_index, 0)
print(max(distance))
import sys
import heapq
INF = int(1e9)
N = int(sys.stdin.readline())
graph = [[] for i in range(N + 1)]
visited = [False] * (N + 1)
distance = [INF] * (N + 1)
for _ in range(N - 1):
a, b, c = map(int, sys.stdin.readline().split())
graph[a].append((b, c))
graph[b].append((a, c))
def dijkstra(start):
q = []
heapq.heappush(q, (0, start))
distance[start] = 0
while q:
dist, now = heapq.heappop(q)
if distance[now] < dist:
continue
for i in graph[now]:
cost = dist + i[1]
if cost < distance[i[0]]:
distance[i[0]] = cost
heapq.heappush(q, (cost, i[0]))
dijkstra(1)
distance[0] = -1
new_start = distance.index(max(distance))
distance = [INF] * (N + 1)
visited = [False] * (N + 1)
dijkstra(new_start)
distance[0] = -1
print(max(distance))
다익스트라