백준 / 13549 / 숨바꼭질 3 / python

맹민재·2023년 5월 11일
0

알고리즘

목록 보기
90/134


처음 시도 코드

from heapq import heappop, heappush

n, k = [int(v) for v in input().split()]

dist = [1e9] * (k+2)
dist[n] = 0

h = []
heappush(h, (0, n))

while h:
    dis, node = heappop(h)
    if dist[node] < dis:
        continue

    if node + 1 <= l and dis + 1 < dist[node+1]:
        dist[node+1] = dis+1
        heappush(h, (dis+1, node+1))
    if node - 1 >= 0 and dis + 1 < dist[node-1]:
        dist[node-1] = dis+1
        heappush(h, (dis+1, node-1))
    if node * 2 <= l and dis < dist[node*2]:
        dist[node*2] = dis
        heappush(h, (dis, node*2))

print(dist[k])

우선 다익스트라 알고리즘으로 해결이 가능해 보여서 다익스트라 알고리즘을 적용해서 풀었었다. 주어진 예제에 대해서는 맞았지만 제출 했을 때 indexError가 났다.

이유를 생각해보니 k보다 n이 클 때 당연히 indexError가 난다.

수정

from heapq import heappop, heappush

n, k = [int(v) for v in input().split()]

l = k if k > n else n

dist = [1e9] * (l+1)
dist[n] = 0

h = []
heappush(h, (0, n))

while h:
    dis, node = heappop(h)
    if dist[node] < dis:
        continue

    if node + 1 <= l and dis + 1 < dist[node+1]:
        dist[node+1] = dis+1
        heappush(h, (dis+1, node+1))
    if node - 1 >= 0 and dis + 1 < dist[node-1]:
        dist[node-1] = dis+1
        heappush(h, (dis+1, node-1))
    if node * 2 <= l and dis < dist[node*2]:
        dist[node*2] = dis
        heappush(h, (dis, node*2))

print(dist[k])

이번에는 IndexError는 피했지만 제출했을 때 틀렸다고 뜬다.
다시 반례를 찾다가 틀린 이유 발견
*2로 진행 후 -1을 했을 때가 최단인 경우이다.
예를 들면 4 -> 7 인 경우 4 -> 8 -> 7 이렇게가 최단 경로인데 dist 배열을 k까지만 만들게 되면 해당 경우를 탐색할 수 없다.
다시 수정

최종 코드

from heapq import heappop, heappush

n, k = [int(v) for v in input().split()]

l = k if k > n else n

dist = [1e9] * (l+2)
dist[n] = 0

h = []
heappush(h, (0, n))

while h:
    dis, node = heappop(h)
    if dist[node] < dis:
        continue

    if node + 1 <= l and dis + 1 < dist[node+1]:
        dist[node+1] = dis+1
        heappush(h, (dis+1, node+1))
    if node - 1 >= 0 and dis + 1 < dist[node-1]:
        dist[node-1] = dis+1
        heappush(h, (dis+1, node-1))
    if node * 2 <= l+1 and dis < dist[node*2]:
        dist[node*2] = dis
        heappush(h, (dis, node*2))

print(dist[k])

드디어 해결되었다.


사실 길이를 입력에 따라 바꾸는 방식말고 최대 크기가 10만이기 때문에 list의 크기를 100001로 고정해도 충분히 풀 수 있는 문제다.
이런걸 삽질이라고 해야하나....
그래도 덕분에 문제를 풀 때 다양한 반례를 생각할 수 있었던 좋은 경험이였다.

profile
ㄱH ㅂrㄹ ㅈr

0개의 댓글