BOJ 23840 - 두 단계 최단 경로 4 링크
(2023.04.20 기준 P5)
N개의 정점과 양방향 간선 M개가 주어진다. 출발 정점 X에서 출발해서 P개의 중간 정점 모두를 거친 후 도착 정점 Z에 도달하는 최단 거리 출력
N은 최대 100,000이지만 P는 최대 20이다. P개의 정점을 모두 방문하는 최단 거리를 외판원 순회로 구하면 O(20 ** 2 * 2 ** 20) = 약 4초.
외판원 순회와 최종 결과를 위해 중간 정점과 시작 정점, 도착 정점. 최대 22개의 정점 간 거리를 구해야 한다. 다익스트라로 구하면 O(100000 * log 100000) * 22 = 약 0.25초. 이 문제의 시간 제한은 7초이므로 이 방법이 충분한 풀이가 될 것이다.
이 문제는 모든 정점을 방문할 필요가 없다. 출발 정점 X, 도착 정점 Z, 그리고 주어지는 P개의 중간 정점들을 방문해야 한다.
하지만 X와 Z는 처음과 마지막으로 고정되어 있다. 순서가 바뀔 수 있는 것은 중간 정점들.
그러면 간단하다. P개의 중간 정점들만 외판원 순회를 돌리되, 마지막에 첫번째로 돌아가는 것이 아니라 도착 정점으로 가면 되는 것이다.func TSP(현재, 방문한 정점들): # 외판원 순회 if (방문한 정점들 == (1 << P) - 1): # P개의 중간 정점을 모두 방문하면 return cost[현재][첫번째] -> X return cost[현재][도착 정점] -> O # 도착 정점으로 가면 된다.
그리고 처음에 외판원 순회를 시작할 때에는 시작 정점부터 시작하면 된다. 당연히 방문한 정점들은 0으로 시작하는 것이다.
그리고 P개의 중간 정점들로만 방문하게 하면 되는 것이다.이렇게 하려면 중간 정점들, 시작 정점, 도착 정점. 총 최대 22개의 정점 간 거리가 필요하다.
이는 그냥 다익스트라를 최대 22번 돌려서 거리를 구해놓으면 된다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef priority_queue<pll, vector<pll>, greater<pll>> heapq;
const ll inf = 1e16;
int P, Y[22];
ll cost[22][22], dp[21][1 << 20];
vector<pll> graph[100000];
void dijkstra(vector<ll> &distance, int root){ // 다익스트라
heapq pq;
pq.push({0, root});
distance[root] = 0;
while (!pq.empty()){
pll here = pq.top(); pq.pop();
if (distance[here.second] < here.first) continue;
for (pll there: graph[here.second]) if (distance[there.first] > here.first + there.second)
distance[there.first] = here.first + there.second, pq.push({distance[there.first], there.first});
}
}
ll dfs(int here, int visited){ // 외판원 순회
if (visited == (1 << P) - 1) // P개의 중간 정점을 모두 방문하면
return cost[here][P + 1]; // 도착 정점인 P+1번으로 가면 된다.
if (dp[here][visited] > -1)
return dp[here][visited];
dp[here][visited] = inf;
for (int there = 0; there < P; there++)
if (!(visited & (1 << there)) && cost[here][there] < inf)
dp[here][visited] = min(dp[here][visited], dfs(there, visited | (1 << there)) + cost[here][there]);
return dp[here][visited];
}
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
int N, M;
cin >> N >> M;
for (int i = 0, u, v, w; i < M; i++){
cin >> u >> v >> w;
graph[--u].push_back({--v, w});
graph[v].push_back({u, w});
}
int X, Z;
cin >> X >> Z;
cin >> P;
for (int i = 0; i < P; i++) cin >> Y[i], Y[i]--;
Y[P] = --X, Y[P + 1] = --Z; // 출발 정점은 P번, 도착 정점은 P+1번이 된다.
// P+2개의 정점 간의 거리를 다익스트라로 구해주자.
vector<ll> distance(N);
for (int i = 0; i < P + 2; i++){
fill(distance.begin(), distance.end(), inf);
dijkstra(distance, Y[i]);
for (int j = 0; j < P + 2; j++) cost[i][j] = distance[Y[j]];
}
// 출발지를 출발 정점인 P번으로 하여금 외판원 순회
memset(dp, -1, sizeof(dp));
ll result = dfs(P, 0);
if (result < inf) cout << result;
else cout << -1;
}
import sys; input = sys.stdin.readline
from math import inf
from heapq import heappop, heappush
def dijkstra(distance, root): # 다익스트라
queue = [(0, root)]
distance[root] = 0
while queue:
d, here = heappop(queue)
if distance[here] < d:
continue
for there, dd in graph[here]:
if distance[there] > d + dd:
distance[there] = d + dd
heappush(queue, (distance[there], there))
def dfs(here, visited): # 외판원 순회
if visited == (1 << P) - 1: # P개의 중간 정점을 모두 방문하면
return cost[here][P + 1] # 도착 정점인 P+1번으로 가면 된다.
if dp[here][visited] > -1:
return dp[here][visited]
dp[here][visited] = inf
for there in range(P):
if not visited & (1 << there) and cost[here][there] < inf:
dp[here][visited] = min(dp[here][visited], dfs(there, visited | (1 << there)) + cost[here][there])
return dp[here][visited]
N, M = map(int, input().split())
graph = [[] for _ in range(N)]
for _ in range(M):
u, v, w = map(int, input().split())
u -= 1; v -= 1
graph[u].append((v, w))
graph[v].append((u, w))
X, Z = map(int, input().split())
X -= 1; Z -= 1
P = int(input())
Y = list(map(lambda x: int(x) - 1, input().split()))
Y.append(X); Y.append(Z) # 출발 정점은 P번, 도착 정점은 P+1번이 된다.
# P+2개의 정점 간의 거리를 다익스트라로 구해주자.
cost = [[0] * (P + 2) for _ in range(P + 2)]
for i in range(P + 2):
distance = [inf] * N
dijkstra(distance, Y[i])
for j in range(P + 2):
cost[i][j] = distance[Y[j]]
# 출발지를 출발 정점인 P번으로 하여금 외판원 순회
dp = [[-1] * (1 << P) for _ in range(P + 1)]
result = dfs(P, 0)
print(result) if result < inf else print(-1)