매일 아침, 세준이는 학교에 가기 위해서 차를 타고 D킬로미터 길이의 고속도로를 지난다. 이 고속도로는 심각하게 커브가 많아서 정말 운전하기도 힘들다. 어느 날, 세준이는 이 고속도로에 지름길이 존재한다는 것을 알게 되었다. 모든 지름길은 일방통행이고, 고속도로를 역주행할 수는 없다.
세준이가 운전해야 하는 거리의 최솟값을 출력하시오.
첫째 줄에 지름길의 개수 N과 고속도로의 길이 D가 주어진다. N은 12 이하인 양의 정수이고, D는 10,000보다 작거나 같은 자연수이다. 다음 N개의 줄에 지름길의 시작 위치, 도착 위치, 지름길의 길이가 주어진다. 모든 위치와 길이는 10,000보다 작거나 같은 음이 아닌 정수이다. 지름길의 시작 위치는 도착 위치보다 작다.
세준이가 운전해야하는 거리의 최솟값을 출력하시오.
DP풀이
방향과 가중치가 있기 때문에 다익스트라로 해결할 수 있지만, DP로 해결해보았다
문제 풀면서 주의해야할 부분들이 있는데 첫 번째로, 지름길이 항상 지름길은 아니라는 것이다
10 90 100 처럼 지름길을 통해서 갈 때 거리가 더 길어지는 경우가 존재한다
따라서 나는 지름길 정보를 입력받을 때 이러한 경우는 제외하고 리스트에 담았다
그리고 지름길을 통해서 도착할 수 있는 곳이라면 지름길의 시작점까지 도달하는 거리에 지름길의 거리를 더해주어서 dp 리스트를 갱신해야한다
또한, 그렇게 최소거리로 갱신되었을 경우 그 다음부터 이어지는 거리들도 갱신해야한다
예를 들어, 0 5 1 이 주어진다면 행이 거리를 의미할 때
| D | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 2 | 3 | 4 | 1 | 2 | 3 | 4 | 5 |
의 형태로 거리들이 갱신되어야 한다 그리고 이 거리들은 지름길들에 따라 계속해서 바뀔 수 있기 때문에 지름길의 도착지점에 해당하지 않는 지점들은 이전 지점까지의 거리에 +1한 것과 현재 지점까지의 거리를 비교해 더 작은 값을 취해야한다
코드
import sys
n, d = map(int, sys.stdin.readline().split())
dp = [i for i in range(d+1)]
shortcuts = []
for _ in range(n):
start, dest, length = map(int, sys.stdin.readline().split())
if dest - start > length:
shortcuts.append((start, dest, length))
shortcuts.sort()
for start, dest, length in shortcuts:
for i in range(1, d+1):
if dest == i:
dp[i] = min(dp[i], dp[start]+length)
else:
dp[i] = min(dp[i], dp[i-1]+1)
print(dp[d])
다익스트라 풀이
heapq를 사용해서 다익스트라로 풀어보았다
0부터 d까지 모든 수를 하나의 노드라고 생각하고 각각의 수들이 갖는 가중치를 1이라고 가정한다
방향은 0->1, 1->2 와 같이 자신의 다음에 오는 수를 향한다
다익스트라를 사용할 때는 도착 지점과 출발 지점의 거리 차이와 지름길 길이의 관계를 고려하지 않아도 된다
다만, 도착 지점이 d보다 큰 경우는 제외해야한다
시작 노드인 0의 cost를 0으로 가정하고 heap에 push한다
이후 시작 노드와 연결된 다른 노드들을 하나하나 꺼내어 시작노드를 거쳐서 연결된 노드까지 가는 경우와 단숨에 연결된 노드까지 가는 경우를 비교해 더 cost가 작은 것을 선택한다
코드
import sys, heapq
n, d = map(int, sys.stdin.readline().split())
inf = float("inf")
graph = [[] for _ in range(d+1)]
dist = [inf]*(d+1)
for i in range(d):
graph[i].append((i+1, 1))
for _ in range(n):
start, dest, length = map(int, sys.stdin.readline().split())
if dest<=d:
graph[start].append((dest, length))
q = []
heapq.heappush(q, (0,0))
dist[0] = 0
while q:
w1, u = heapq.heappop(q)
for v, w2 in graph[u]:
cost = dist[u] + w2
if dist[v] > cost:
dist[v] = cost
heapq.heappush(q, (cost, v))
print(dist[d])