백준 1504번: 특정한 최단 경로 [Python]

kimminjunnn·2025년 12월 30일

알고리즘

목록 보기
280/311

문제 출처 : https://www.acmicpc.net/problem/1504
난이도 : 골드 4


문제 파악

정점 1에서 정점 N까지 가는 “최단 경로”를 구하는데, 반드시 v1, v2를 지나야 한다.

즉, 가능한 경로는 딱 2가지 경우로 줄어든다.

  1. 1 -> v1 -> v2 -> N
  2. 1 -> v2 -> v1 -> N

이 두 경로의 길이를 각각 구해서 더 작은 값을 출력하면 된다.

단, 둘 다 불가능하면 -1.

해결 아이디어

각 경로는 “세 구간” 합으로 쪼갤 수 있다.

  • 1 -> v1
  • v1 -> v2
  • v2 -> N

그리고 반대 케이스도 마찬가지.

  • 1 -> v2
  • v2 -> v1
  • v1 -> N

여기서 핵심은 각 구간의 최단거리를 알아야 한다는 점.

그래서 다익스트라는 start -> 모든 노드 거리 배열을 뽑아주는 형태로 만들고,
딱 3번만 돌린다.

  • dist1 = dijkstra(1) : 1에서 모든 노드까지
  • distV1 = dijkstra(v1) : v1에서 모든 노드까지
  • distV2 = dijkstra(v2) : v2에서 모든 노드까지

이렇게 해두면 필요한 구간은 배열에서 꺼내서 더하기만 하면 끝이다.

해답 및 풀이

import sys
input = sys.stdin.readline
import heapq

INF = 1e9

N, E = map(int,input().split())
graph = [[] for _ in range(N+1)]

for _ in range(E):
    a, b, c = map(int,input().split())
    graph[a].append((b,c))
    graph[b].append((a,c))

# 반드시 거쳐야 하는 정점 v1, v2
v1, v2 = map(int,input().split())

# 1부터 N까지 가는 최단 경로를 구해야 하는데,
# 반드시 v1, v2를 거쳐야 한다.
#
# 가능한 경우는 2가지 뿐
# 1) 1 -> v1 -> v2 -> N
# 2) 1 -> v2 -> v1 -> N
# 각 경우의 거리 합을 구해서 더 작은 값을 출력한다.

# Dijkstra: start에서 모든 노드까지의 최단거리(dist 배열)를 채워서 반환
def Dijkstra(start):
    dist = [INF] * (N+1)
    dist[start] = 0

    # 우선순위 큐 (거리, 정점)
    queue = []
    heapq.heappush(queue, (0, start))

    while queue:
        cur_dist, cur_node = heapq.heappop(queue)

        # 이미 더 짧은 거리로 방문한 적이 있으면 무시
        if cur_dist > dist[cur_node]:
            continue

        for next_node, weight in graph[cur_node]:
            new_dist = cur_dist + weight

            if new_dist < dist[next_node]:
                dist[next_node] = new_dist
                heapq.heappush(queue, (new_dist, next_node))

    return dist

# 다익스트라는 3번만 돌리면 된다.
# 1에서의 거리들, v1에서의 거리들, v2에서의 거리들
dist1 = Dijkstra(1)
distV1 = Dijkstra(v1)
distV2 = Dijkstra(v2)

# 1) 1 -> v1 -> v2 -> N
v1Tov2 = dist1[v1] + distV1[v2] + distV2[N]

# 2) 1 -> v2 -> v1 -> N
v2Tov1 = dist1[v2] + distV2[v1] + distV1[N]

answer = min(v1Tov2, v2Tov1)

# 도달 불가(INF)가 섞이면 합이 INF보다 커질 수 있어서
# answer == INF 가 아니라 answer >= INF 로 체크해야 한다.
if answer >= INF:
    print(-1)
else:
    print(answer)
profile
Frontend Engineers

0개의 댓글