오답 코드1 (시간 초과, dijkstra)
import sys
import heapq
# Read the tree: the vertex count, then n - 1 undirected weighted edges.
n = int(input())
# Adjacency list indexed by vertex (1..n, slot 0 unused); each entry is a
# [neighbor, weight] pair.
node = [[] for _ in range(n + 1)]
readline = sys.stdin.readline
for _ in range(n - 1):
    a, b, c = map(int, readline().split())
    # Undirected edge: record it from both endpoints.
    for src, dst in ((a, b), (b, a)):
        node[src].append([dst, c])
def dijkstra(start, dest):
    """Lower ``dest`` in place to shortest-path distances from ``start``.

    Runs Dijkstra's algorithm over the module-level adjacency list ``node``,
    where ``node[u]`` holds ``[neighbor, weight]`` pairs.

    Args:
        start: source vertex.
        dest: distance list indexed by vertex. The caller pre-fills it with a
            sentinel that must exceed every real distance. Entries are lowered
            in place, and ``dest[start]`` is set to 0.
    """
    dest[start] = 0
    heap = [(0, start)]
    while heap:
        dist, u = heapq.heappop(heap)
        # Stale heap entry: a shorter path to u was already recorded.
        if dest[u] < dist:
            continue
        # Loop names are `nbr`/`weight`, not `n`, so they do not shadow the
        # module-level vertex count `n`.
        for nbr, weight in node[u]:
            cand = dist + weight
            if cand < dest[nbr]:
                dest[nbr] = cand
                heapq.heappush(heap, (cand, nbr))
# For each vertex, run Dijkstra from it and print the sum of its distances
# to every other vertex. One run per source makes this O(n^2 log n) overall,
# which is why this attempt exceeded the time limit.
for i in range(1, n + 1):
    # Use float('inf') rather than a finite sentinel such as 987654321. With a
    # finite sentinel, a vertex whose true distance is >= that value would
    # never be relaxed and would be summed as the sentinel.
    dest = [float('inf')] * (n + 1)
    dest[0] = 0  # slot 0 is unused (vertices are 1-based); keep it out of the sum
    dijkstra(i, dest)
    print(sum(dest))
오답 코드2 (시간 초과, BFS)
import sys
import copy
from collections import deque

# Read the tree: the vertex count, then n - 1 undirected weighted edges.
n = int(input())
node = [[] for _ in range(n + 1)]  # node[u] = [[neighbor, weight], ...]; slot 0 unused
for _ in range(n - 1):
    a, b, c = map(int, sys.stdin.readline().split())
    node[a].append([b, c])
    node[b].append([a, c])

# For each source vertex, run a BFS over the tree and add up the distance to
# every other vertex as it is reached. Tree paths are unique, so the first
# time a vertex is reached gives its distance. One BFS per source makes this
# O(n^2) overall, which is why this attempt exceeded the time limit.
for i in range(1, n + 1):
    v = [False] * (n + 1)
    v[i] = True
    # deque.popleft is O(1); the previous `del queue[0]` on a list was O(n).
    queue = deque([(i, 0)])
    res = 0
    while queue:
        curr, dist = queue.popleft()
        for nxt, weight in node[curr]:
            if not v[nxt]:
                # Mark on enqueue so no vertex can be queued twice.
                v[nxt] = True
                d = dist + weight
                res += d
                queue.append((nxt, d))
    # Only the answer is printed; the per-pop debug print was removed because
    # it corrupted the program's output.
    print(res)
정답 코드 (DFS)
import sys
# Raise Python's recursion ceiling so deep recursive traversals of a
# path-shaped tree do not stop with RecursionError. The sys docs warn that a
# very high limit can crash the interpreter if recursion actually gets that deep.
RECURSION_LIMIT = 10 ** 8
sys.setrecursionlimit(RECURSION_LIMIT)
def dfs1(current, parent):
    """Compute subtree sizes and downward distance sums for the tree rooted at ``current``.

    Reads the module-level adjacency list ``node`` (``[neighbor, weight]``
    pairs) and writes the module-level lists ``subtreeSize`` and ``distSum``.
    ``distSum`` is expected to start zero-filled. After the call, for every
    vertex ``u`` in the traversed tree:

    * ``subtreeSize[u]`` is the number of vertices in u's subtree;
    * ``distSum[u]`` is the sum of distances from ``u`` to the vertices in
      its subtree.

    Args:
        current: root vertex of the traversal.
        parent: vertex treated as the root's parent and never descended into.
            Callers with no real parent may pass the root itself, since a tree
            has no self-loops.

    Uses an explicit stack instead of recursion. A recursive version needs one
    Python frame per tree level, so a path-shaped tree exceeds the default
    recursion limit.
    """
    # Pre-order walk: each vertex is recorded after its parent, so reading the
    # record backwards visits every child before its parent.
    order = []
    stack = [(current, parent)]
    while stack:
        u, p = stack.pop()
        order.append((u, p))
        for child, _ in node[u]:
            if child != p:
                stack.append((child, u))
    # Post-order accumulation: a child's totals are final before its parent reads them.
    for u, p in reversed(order):
        subtreeSize[u] = 1
        for child, weight in node[u]:
            if child != p:
                # Every vertex under `child` is `weight` farther from u than from child.
                distSum[u] += distSum[child] + subtreeSize[child] * weight
                subtreeSize[u] += subtreeSize[child]
    return
def dfs2(current, parent):
    """Reroot the tree: derive each vertex's total distance sum from its parent's.

    Prerequisites:

    * ``subtreeSize`` holds subtree sizes for the tree rooted at ``current``;
    * ``distSum[current]`` already holds the total for ``current``.

    Moving the root across an edge of weight ``w`` from ``u`` to ``child``
    brings the ``subtreeSize[child]`` vertices on child's side ``w`` closer,
    and pushes the other ``n - subtreeSize[child]`` vertices ``w`` farther.
    Hence ``distSum[child] = distSum[u] + w * (n - 2 * subtreeSize[child])``.

    Reads the module-level ``node``, ``n`` and ``subtreeSize``, and overwrites
    ``distSum`` for every vertex except ``current``.

    Args:
        current: root vertex whose ``distSum`` is already final.
        parent: vertex treated as the root's parent and never descended into.

    Uses an explicit stack instead of recursion, so a deep tree cannot exceed
    the recursion limit. A child's value depends only on its parent's
    already-final value, so visiting order does not affect the result.
    """
    stack = [(current, parent)]
    while stack:
        u, p = stack.pop()
        for child, weight in node[u]:
            if child != p:
                distSum[child] = distSum[u] + weight * (n - 2 * subtreeSize[child])
                stack.append((child, u))
    return
# Driver: read a weighted tree with n vertices and print, for every vertex
# 1..n, the sum of its distances to all other vertices.
n = int(sys.stdin.readline())
node = [[] for _ in range(n + 1)]  # node[u] = [[neighbor, weight], ...]; slot 0 unused
subtreeSize = [0] * (n + 1)
distSum = [0] * (n + 1)
for _ in range(n - 1):
    a, b, c = map(int, sys.stdin.readline().split())
    node[a].append([b, c])
    node[b].append([a, c])
# Root the tree at vertex 1. Passing the root as its own parent is safe
# because a tree has no self-loops.
dfs1(1, 1)  # afterwards distSum[1] is the answer for vertex 1
dfs2(1, 1)  # propagate the answer to every other vertex
# One write instead of n print() calls: same lines, same trailing newline.
sys.stdout.write("\n".join(map(str, distSum[1:])) + "\n")