[ BOJ / Python ] 1238번 파티

황승환·2022년 3월 2일
0

Python

목록 보기
214/498


이번 문제는 다익스트라 알고리즘을 활용하여 해결하였다. 이 문제에서는 다익스트라 알고리즘을 함수로 따로 구현해야 한다. 여러번 사용해야 하기 때문이다. 결과적으로 모든 마을에서의 다익스트라 함수를 호출해야 한다. x마을을 출발점으로 호출한 경우에는 모든 마을의 비용에 해당 최소 비용을 각각 더해주어야 하고 (x마을에서 각자의 마을로 돌아가는 비용), 그 외의 마을을 출발점으로 호출한 경우에는 해당 마을의 비용에 x마을까지의 최소 비용을 더해주어야 한다(각자의 마을에서 x마을까지 가는 비용). 각 마을의 최종 비용이 저장된 리스트에서 가장 큰 값을 출력하면 문제에서 원하는 답을 출력할 수 있다.

  • n, m, x를 입력받는다.
  • 마을과 마을 간의 길을 저장할 2차원 리스트 graph를 선언한다.
  • m번 반복하는 for문을 돌린다.
    -> a, b, c를 입력받는다.
    -> graph[a][b, c]를 넣는다.
  • Dijkstra함수를 start를 인자로 갖도록 선언한다.
    -> q를 최소힙으로 선언한다.
    -> q에 [0, start]를 넣는다.
    -> 가장 큰 값을 INF 변수에 저장한다.
    -> 거리를 저장할 리스트 dist를 INF n+1개를 넣어 선언한다.
    -> q가 존재하는 동안 반복하는 while문을 돌린다.
    --> q에서 cost, cur을 추출한다.
    --> 만약 cost가 dist[cur]보다 클 경우, 다음 반복으로 넘어간다.
    --> graph[cur]을 순회하는 nxt, c에 대한 for문을 돌린다.
    ---> nxt_c를 cost+c로 저장한다.
    ---> 만약 dist[nxt]가 nxt_c보다 클 경우,
    ----> dist[nxt]를 nxt_c로 갱신한다.
    ----> q에 [nxt_c, nxt]를 넣는다.
    -> dist를 반환한다.
  • 정답을 저장할 리스트 answers를 0 n+1개로 채운다.
  • 1부터 n까지 반복하는 i에 대한 for문을 돌린다.
    -> 만약 i가 x와 같을 경우,
    --> 임시 변수 tmp에 Dijkstra(i)의 반환값을 저장한다.
    --> 1부터 n까지 반복하는 j에 대한 for문을 돌린다.
    ---> 만약 j가 x와 같을 경우, 다음 반복으로 넘어간다.
    ---> 그 외의 경우, answers[j]tmp[j]를 더한다.
    -> 그 외의 경우,
    --> 임시 변수 tmp에 Dijkstra(i)의 반환값을 저장한다.
    --> answers[i]tmp[x]의 값을 더한다.
  • answers의 최댓값을 출력한다.

Code

import heapq
import sys
n, m, x=map(int, input().split())
graph=[[] for _ in range(n+1)]
for _ in range(m):
    a, b, c=map(int, input().split())
    graph[a].append([b, c])
def Dijkstra(start):
    q=[]
    heapq.heappush(q, [0, start])
    INF=sys.maxsize
    dist=[INF for _ in range(n+1)]
    while q:
        cost, cur=heapq.heappop(q)
        if cost>dist[cur]:
            continue
        for nxt, c in graph[cur]:
            nxt_c=cost+c
            if dist[nxt]>nxt_c:
                dist[nxt]=nxt_c
                heapq.heappush(q, [nxt_c, nxt])
    return dist
answers=[0 for _ in range(n+1)]
for i in range(1, n+1):
    if i==x:
        tmp=Dijkstra(i)
        for j in range(1, n+1):
            if j==x:
                continue
            else:
                answers[j]+=tmp[j]
    else:
        tmp=Dijkstra(i)
        answers[i]+=tmp[x]
print(max(answers))

profile
꾸준함을 꿈꾸는 SW 전공 학부생의 개발 일기

0개의 댓글