Softeer [21년 재직자 대회 본선] 거리 합 구하기

Yibangwon·2022년 9월 16일
0

알고리즘 문제풀이

목록 보기
57/60

오답 코드1 (시간 초과, dijkstra — 모든 정점을 시작점으로 다익스트라를 n번 수행하므로 전체 O(n² log n))

import sys
import heapq

# Input: vertex count n, then n - 1 lines "a b c" describing an undirected
# edge a-b with weight c.
n = int(input())
node = [[] for _ in range(n + 1)]  # node[v] -> list of [neighbor, weight]

for i in range(n - 1):
    a, b, c = map(int, sys.stdin.readline().split())
    # Undirected edge: register it from both endpoints.
    node[a] += [[b, c]]
    node[b] += [[a, c]]

def dijkstra(start, dest):
    """Lower ``dest`` in place to shortest-path distances from ``start``.

    Only ``dest[start]`` is initialized here; every other entry is only ever
    decreased when a shorter path is found, so the caller must pre-fill
    ``dest`` with a value larger than any real distance. Vertices not
    reachable from ``start`` keep that pre-filled value. Adjacency lists are
    read from the module-level ``node``.
    """
    dest[start] = 0
    pq = [(0, start)]

    while pq:
        dist_here, here = heapq.heappop(pq)
        if dist_here > dest[here]:
            # Stale entry: a shorter path to `here` was already recorded.
            continue
        for neighbor, weight in node[here]:
            candidate = dist_here + weight
            if candidate < dest[neighbor]:
                dest[neighbor] = candidate
                heapq.heappush(pq, (candidate, neighbor))


# One single-source search per vertex (n searches in total); print the sum
# of distances from each start vertex.
for i in range(1, n + 1):
    dest = [987654321] * (n + 1)
    dest[0] = 0  # slot 0 is unused; zero it so it does not affect the sum
    dijkstra(i, dest)
    print(sum(dest))

오답 코드2 (시간 초과, BFS — 모든 정점에서 BFS를 n번 수행하므로 전체 O(n²))

import sys
import copy

# Input: vertex count n, then n - 1 lines "a b c" describing an undirected
# edge a-b with weight c.
n = int(input())
node = [[] for i in range(n + 1)]

for i in range(n - 1):
    a, b, c = map(int, sys.stdin.readline().split())
    node[a].append([b, c])
    node[b].append([a, c])

# For every start vertex, walk the whole tree breadth-first and add up the
# distance to every other vertex. Doing a full traversal per start vertex is
# what makes this approach too slow overall.
for i in range(1, n + 1):
    v = [False for j in range(n + 1)]
    queue = [[i, 0]]  # entries are [vertex, distance from i]
    head = 0  # index of the next entry to dequeue (avoids O(len) del queue[0])
    res = 0
    while head < len(queue):
        curr = queue[head]
        head += 1
        v[curr[0]] = True
        # (A debug print of the current vertex used to be here; it mixed
        # extra lines into the required output.)
        for n1 in node[curr[0]]:
            if not v[n1[0]]:
                res += n1[1] + curr[1]
                queue.append([n1[0], n1[1] + curr[1]])
    print(res)

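정답 코드 (DFS — 트리 DP/리루팅). 첫 번째 DFS(dfs1)로 루트 1 기준 각 정점의 서브트리 정점 수 subtreeSize와 서브트리 안 정점들까지의 거리 합을 구한다. 두 번째 DFS(dfs2)로 루트를 부모에서 자식으로 옮긴다. 간선 (parent, child, w)를 건너면 child 서브트리의 subtreeSize[child]개 정점은 w만큼 가까워지고, 나머지 n − subtreeSize[child]개 정점은 w만큼 멀어진다. 따라서 distSum[child] = distSum[parent] + w × (n − 2 × subtreeSize[child])이며, 모든 정점의 답을 O(n)에 구할 수 있다.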

import sys
# Allow deep recursion: a recursive traversal of a path-shaped tree with n
# vertices nests about n calls, far beyond Python's default limit (1000).
# This only lifts Python's own check; extremely deep recursion can still
# overflow the C stack and crash the interpreter.
sys.setrecursionlimit(int(1e8))

def dfs1(current, parent):  # bottom-up accumulation (post-order)
    """Compute subtree sizes and in-subtree distance sums rooted at ``current``.

    For every vertex v reachable from ``current`` without stepping back into
    ``parent``, this fills the module-level lists:
      * subtreeSize[v] -- number of vertices in v's subtree (v included);
      * distSum[v]     -- increased by the sum of weighted distances from v
                          to every vertex in its subtree.

    Uses an explicit stack instead of recursion, so a path-shaped tree cannot
    exhaust the Python or C call stack.
    """
    subtreeSize[current] = 1
    order = []  # (child, parent, weight) in DFS discovery order
    stack = [(current, parent)]
    while stack:
        vertex, par = stack.pop()
        for child, weight in node[vertex]:
            if child != par:
                subtreeSize[child] = 1
                order.append((child, vertex, weight))
                stack.append((child, vertex))
    # A vertex is always discovered after its parent, so walking the
    # discovery order backwards finishes every child before folding it into
    # its parent.
    for child, par, weight in reversed(order):
        distSum[par] += distSum[child] + subtreeSize[child] * weight
        subtreeSize[par] += subtreeSize[child]
    return

def dfs2(current, parent):  # top-down propagation (pre-order)
    """Re-root ``distSum`` so each vertex holds its distance sum to all n vertices.

    Assumes dfs1 already filled subtreeSize/distSum from the same root and
    that distSum[current] is already the complete answer for ``current``.
    Moving the root across an edge (vertex -> child, weight w) brings the
    subtreeSize[child] vertices of child's subtree w closer and the other
    n - subtreeSize[child] vertices w farther away, hence
        distSum[child] = distSum[vertex] + w * (n - 2 * subtreeSize[child]).

    Uses an explicit stack instead of recursion, so a path-shaped tree cannot
    exhaust the Python or C call stack. Reads the module-level ``n``, ``node``
    and ``subtreeSize``; writes ``distSum``.
    """
    stack = [(current, parent)]
    while stack:
        vertex, par = stack.pop()
        for child, weight in node[vertex]:
            if child != par:
                # distSum[vertex] is final before vertex is popped, so the
                # child's value can be derived from it right away.
                distSum[child] = distSum[vertex] + weight * (n - 2 * subtreeSize[child])
                stack.append((child, vertex))
    return

# Input: vertex count n, then n - 1 lines "a b c" describing an undirected
# edge a-b with weight c. n, node, subtreeSize and distSum are module-level
# because dfs1/dfs2 read and write them directly.
n = int(sys.stdin.readline())
node = [[] for i in range(n + 1)]  # node[v] -> list of [neighbor, weight]
subtreeSize = [0 for i in range(n + 1)]
distSum = [0 for i in range(n + 1)]
for i in range(n - 1):
    a, b, c = map(int, sys.stdin.readline().split())
    node[a].append([b, c])
    node[b].append([a, c])

dfs1(1, 1)  # rooted at 1: subtree sizes and distance sums into each subtree
dfs2(1, 1)  # re-root from 1 outward: full distance sum for every vertex
# One write instead of n separate print() calls; same text, one line per vertex.
sys.stdout.write("".join(f"{total}\n" for total in distSum[1:]))
profile
I Don’t Hope. Just Do.

0개의 댓글