다익스트라 -바킹독 유튜브

코변·2022년 11월 20일
0

알고리즘 공부

목록 보기
45/65

알고리즘 설명

하나의 시작점으로 하나의 시작점으로부터 다른 모든 정점까지의 최단거리를 구하는 알고리즘(음수 간선이 있다면사용할 수 없다.)

플로이드 알고리즘을 떠올려보면 모든 정점 쌍 사이의 최단거리를 구하는 알고리즘이고 음수 간선에도 사용할 수 있기 때문에 기능을 플로이드가 더 좋다고 할 수 있으나 시간에서 더 효율을 보인다.

다익스트라는 학부 알고리즘 수업에서 반드시 다루는 알고리즘이기도 하고 네트워크 같은데에서 쓰이기도 하기 때문에 이름과 목적은 유명하다.

그러나 구현에서 실수를 할 수 있는 여지가 굉장히 많아서 다익스트라에 대해서 확실히 모르거나 구현에 자신이 없다면 확실히 개념을 잡는 것이 좋다.

구현

최단거리 테이블을 모두 무한대로 두고 자기 자신과의 거리만 0으로 초기화 해준다. 매 단계마다 도달할 수 있는 정점중에서 시작점으로부터의 거리가 가장 가까운 정점을 구해서 그 거리를 확정하는 방식으로 동작한다.

모든 정점과 연결된 정점의 목록을 다 살피면서 최단거리 테이블을 갱신해나가면 O(VE)가 된다. 반면 새 정점을 추가할 때마다 미리 테이블에 거리를 계산해두고 거기서 최솟값을 찾는 방식으로 하면 O(V2 + E)가 되고 일반적으로 E가 V보다 크기 때문에 O(VE)보다 O(V2+E)가 효율적이다.

미리 테이블에 거리를 계산한다라는 말의 의미는 정점을 순회할 때 마다 그 때의 최단거리를 테이블에 일단 써놓고 제일 짧은거리 정점 순으로 순회하면서 최단거리를 갱신해준다.

바킹독님이 설명해주신대로 구현해서 만든 코드 (틀릴 수도 있다.)

N = 6
M = 8
INF = int(10e6)
distances = [INF] * (N+1)
visited = [0] * (N+1)
distances[1] = 0
visited[1] = 1

graph = [[ ] for _ in range(N+1)]
for _ in range(M):
    a, b, c = map(int, input().split())
    
    graph[a].append((b, c))

queue = deque()
queue.append(1)

while queue:
    cur_node = queue.popleft()

    for node in graph[cur_node]:
        distances[node[0]] = min(distances[cur_node] + node[1], distances[node[0]])

    min_node = -1
    min_distance = INF
    for i in range(2, len(distances)):
        if min_distance > distances[i] and not visited[i]:
            min_distance = distances[i]
            min_node = i
    if min_node == -1: break

    visited[min_node] = 1
    queue.append(min_node)


print(distances)

바킹독님이 쓰신 예제

while True:
    idx = -1
    for i in range(1, N+1):
        if visited[i]: continue
        if idx == -1: idx = i
        elif (distances[i] < distances[idx]): idx = i

    print(idx)
    if idx == -1 or distances[idx] == INF:
        break
    visited[idx] = 1
    for next_node in graph[idx]:
        distances[next_node[0]] = min(distances[next_node[0]],distances[idx] + next_node[1] )


print(distances)

다익스트라 알고리즘은 매번 아직 거리가 확정되지 않은 정점들 중에서 가장 가까운 정점을 찾아서 확정한다. 쉽게 말하면 이것도 일종의 그리디 알고리즘이라고 할 수 있고 매번 확정된 정점까지 최단 거리가 우리가 구한 값이라는 걸 증명할 수 있다면 다익스트라가 동작한다는 것을 납득할 수 있다.

단순하게 생각해보면 시작점이 1번정점이라고 생각했을 때 1번정점과 가장 가까운 값을 선택한다고 생각하고 그 정점이 아닌 다른 정점으로 우회해서 그 정점에 더 빠르게 갈 수 있는 곳이 있다고 가정해보면 사실상 그럴 수가 없다는 것을 알 수있다. 다른 정점들은 그 정점으로 가는 거리보다 멀기 때문에 1 보다 3이 작다 같은 말도 안되는 식이 발생하고 결국 정점마다 거리의 최소값을 확정하는 알고리즘은 동작한다는 사실을 알 수 있다.

이 예시를 보면 다익스트라가 왜 음수간선을 처리할 수 없는지도 설명이 가능하다. 지금 현재 가장 짧은 길을 선택해서 나아가는데 나중에 음수간선으로 인해서 길이가 줄어드는 거리가 발생한다면 이미 확정된 값보다 더 짧은 거리가 발생할 수 있기 때문에 그렇다.

다익스트라님이 처음 알고리즘을 만들었을 때는 지금 만든 O(V**2 +E)에 동작하는 알고리즘이었으나 개선의 여지가 있다.

다익스트라 알고리즘은 매번 가장 가까운 정점을 뽑아내야 한다. 원소의 추가가 가능하고 최솟값의 확인/삭제가 효율적이 자료구조인 힙, 우선순위 큐를 활용하면 알고리즘의 속도가 더 빨라진다.

우선순위 큐를 이용한 다익스트라 알고리즘

  1. 우선순위 큐에 (0, 시작점)을 추가
  2. 우선순위 큐에서 거리가 가장 작은 원소를 선택, 해당거리가 최단거리 테이블에 잇는 값과 다를 경우 넘어감
  3. 원소가 가리키는 정점을 v라고 할 때 v와 이웃한 정점들에 대해 최단거리 테이블 값보다 v를 거쳐가는 것이 더 작은 값을 가질 경우 최단거리 테이블 값을 갱신하고 우선순위 큐에 추가
  4. 우선순위 큐가 빌 때꺄지 2, 3번 과정을 반복

시간복잡도가 헷갈릴 수 있는데 우선순위 큐에 추가될 수 있는 원소의 수를 생각해보면 간선 1개당 최대 1개의 원소가 추가될 수 있기 때문에 O(ElogE)라고 할 수 있다.

boj-1753

import heapq
import sys
input = sys.stdin.readline

V, E = map(int, input().split())
start = int(input())
INF = int(10e9)
graph = [[ ] for _ in range(V+1)]
distances = [INF] * (V+1)
distances[start] = 0
for _ in range(E):
    u, v, weight = map(int, input().split())
    
    graph[u].append((weight, v))
    
priority_queue = []
heapq.heappush(priority_queue, (0, start))

while priority_queue:
    weight, cur_node = heapq.heappop(priority_queue)
    
    for next_weight, next_node in graph[cur_node]:
        if weight + next_weight < distances[next_node]:
            distances[next_node] = weight + next_weight
            heapq.heappush(priority_queue, (weight + next_weight, next_node))

print(distances)
for distance in distances[1:]:
    if distance == INF:print("INF")
    else: print(distance)

경로 복원 방법

지금 현재 경로로 오기전 노드를 최단거리 테이블을 갱신할 때마다 pre테이블에 저장해준다. 플로이드알고리즘에서 경로를 확인했던 것처럼 반대로 돌아가면서 확인하면 경로를 확인할 수 있다.

boj-11779

import heapq
import sys
input = sys.stdin.readline

N = int(input())
M = int(input())
graph = [[ ] for _ in range(N+1)]
for _ in range(M):
    u, v, cost = map(int, input().split())
    graph[u].append((cost, v))
    
start, end = map(int, input().split())
INF = int(10e9)
distances = [INF] *(N+1)
possible_road = []
pre_road = [0]* (N+1)
distances[start] = 0
heapq.heappush(possible_road, (0, start))

while possible_road:
    cur_weight, cur_node = heapq.heappop(possible_road)
    if distances[cur_node] != cur_weight: continue
    
    for next_weight, next_node in graph[cur_node]:
        next_distance = cur_weight + next_weight
        if next_distance >= distances[next_node]: continue
        distances[next_node] = next_distance
        heapq.heappush(possible_road, (next_distance, next_node))
        pre_road[next_node] = cur_node
            
print(distances[end])
cur_idx = end
road_stack = []
while cur_idx != start:
    road_stack.append(cur_idx)
    cur_idx = pre_road[cur_idx]
road_stack.append(cur_idx)
print(len(road_stack))
for road in reversed(road_stack):
    print(road, end=' ')
profile
내 것인 줄 알았으나 받은 모든 것이 선물이었다.

0개의 댓글