이번 문제는 다익스트라 알고리즘을 통해 해결하였다. 그래프는 양방향 인접 리스트로 구현하였고, 1부터 v1까지의 최단거리+v1부터 v2까지의 최단거리+v2부터 n까지의 최단거리와 1부터 v2까지의 최단거리+v2부터 v1까지의 최단거리+v1부터 n까지의 최단거리 중 더 작은 값을 결과로 출력하는 방식으로 접근하였다. 이 접근 방식을 그대로 사용하게 되면 6-1, 즉 5번의 다익스트라 알고리즘이 실행되게 된다. 이는 아무리 생각해도 비효율적이었지만 우선은 구현해보자는 생각으로 작성하였다. 당연하게도 시간초과가 계속해서 발생하였다. 이런 저런 방법을 생각해보다가 다른 사람의 코드를 보았다. 나와 다익스트라 알고리즘의 구성은 거의 유사했지만 다익스트라 함수의 반환 값이 리스트라는 점이 달랐다.
리스트 전체를 반환하여 이를 따로 저장하고, 다익스트라의 인자로는 시작 위치만 주어지게 된다. 이렇게 하면 시작 위치부터 n 사이의 모든 노드의 거리가 계산되어 리스트에 저장되므로 단 3번의 다익스트라 함수 호출로 정답을 구할 수 있다.
path=Dijkstra(1)
path1=Dijkstra(v1)
path2=Dijkstra(v2)
result=min(path[v1]+path1[v2]+path2[n], path[v2]+path2[v1]+path1[n])
여기서 path는 1을 시작점으로, path1은 v1을 시작점으로, path2는 v2를 시작점으로 하는 리스트를 저장하게 된다. 그리고 path[v1]은 1부터 v1까지의 최소거리, path1[v2]는 v1부터 v2까지의 최소거리, path2[n]은 v2부터 n까지의 최소거리가 되므로 처음에 접근하고자 했던 방식과 같다.
이렇게 함수 호출의 횟수를 리스트 반환을 통해 5번에서 3번으로 줄여 해결하였다. 이렇게 수정한 코드에서도 시간초과가 발생하여 input을 sys.stdin.readline으로 변경해주었더니 성공하였다.
sys.stdin.readline
으로 선언한다.graph[a]
에 (b, c)
를 넣는다.graph[b]
에 (a, c)
를 넣는다.sys.maxsize
를 저장한다.dist[start]
를 0으로 갱신한다.(0, start)
를 넣어준다.dist[cur]
보다 클 경우 다음 반복으로 넘어간다.graph[cur]
을 순회하는 nxt, dst에 대한 for문을 돌린다.distance+dst
를 저장한다.dist[nxt]
가 cost보다 클 경우,dist[nxt]
를 cost로 갱신한다.(cost, nxt)
를 넣는다.Dijkstra(1)
의 반환 리스트를 저장한다.Dijkstra(v1)
의 반환 리스트를 저장한다.Dijkstra(v2)
의 반환 리스트를 저장한다.path[v1]+path1[v2]+path2[n]
과 path[v2]+path2[v1]+path1[n]
중 더 작은 값을 저장한다.import heapq
import sys
input=sys.stdin.readline
n, e=map(int, input().split())
graph=[[] for _ in range(n+1)]
for _ in range(e):
a, b, c=map(int, input().split())
graph[a].append((b, c))
graph[b].append((a, c))
v1, v2=map(int, input().split())
INF=sys.maxsize
def Dijkstra(start):
dist=[INF for _ in range(n+1)]
dist[start]=0
q=[]
heapq.heappush(q, (0, start))
while q:
distance, cur=heapq.heappop(q)
if distance>dist[cur]:
continue
for nxt, dst in graph[cur]:
cost=distance+dst
if cost<dist[nxt]:
dist[nxt]=cost
heapq.heappush(q, (cost, nxt))
return dist
path=Dijkstra(1)
path1=Dijkstra(v1)
path2=Dijkstra(v2)
result=min(path[v1]+path1[v2]+path2[n], path[v2]+path2[v1]+path1[n])
if result<INF:
print(result)
else:
print(-1)