한 지점에 대해 분기되는 케이스는 아래와 같다.
A)
1
방향 진입 후0
방향 진출
B)1
방향 진입 후1
방향 진출
C)0
방향 진입 후0
방향 진출
D)0
방향 진입 후1
방향 진출
즉, A와 D의 경우 방향 전환이 일어났음을 알 수 있다.
또한 지점에서 지점으로 이동할 때의 시간은 다음과 같이 구해진다.
1. 가로와 세로의 이동 횟수 번으로 동일하다.
2. 방향 전환 횟수만큼 1이 증가한다.
즉, 이동 시간은 (방향 전환 횟수) 이다.
dp 배열에 저장할 정보와 값 아래와 같다.
dp
[x 좌표]
[y 좌표]
[방향 전환 횟수]
[진입 방향(0 or 1)]
= (사용한 연료)
void caseInput()
{
cin >> n >> m >> l >> g;
for (int i = 0; i < n; i++)
for (int j = 1; j < m; j++)
cin >> cost[i][j].first;
for (int i = 1; i < n; i++)
for (int j = 0; j < m; j++)
cin >> cost[i][j].second;
}
연료 소모량 입력
first
에는 0 방향으로 진출할 때의 소모량,
second
에는 1 방향으로 진출할 때의 소모량을 입력받는다.
void init()
{
for (int i = 0; i < 100; i++)
for (int j = 0; j < 100; j++)
for (int x = 0; x < 200; x++)
for (int y = 0; y < 2; y++)
dp[i][j][x][y] = 2e9;
dp[0][0][0][0] = dp[0][0][0][1] = 0;
}
초기화 함수
출발 지점을 제외하고 큰 값으로 초기화한 뒤, 최소 연료 소모량을 갱신해간다.
//dp
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
for (int turn = 0; turn <= i + j; turn++)
{
if (i + 1 < n)
{
//right -> down
dp[i + 1][j][turn + 1][1] = min(dp[i + 1][j][turn + 1][1],
dp[i][j][turn][0] + cost[i + 1][j].second);
//down -> down
dp[i + 1][j][turn][1] = min(dp[i + 1][j][turn][1],
dp[i][j][turn][1] + cost[i + 1][j].second);
}
if (j + 1 < m)
{
//right -> right
dp[i][j + 1][turn][0] = min(dp[i][j + 1][turn][0],
dp[i][j][turn][0] + cost[i][j + 1].first);
//down -> right
dp[i][j + 1][turn + 1][0] = min(dp[i][j + 1][turn + 1][0],
dp[i][j][turn][1] + cost[i][j + 1].first);
}
}
dp 배열 갱신
위에서 설명한 경우에 따라 이동할 지점의 값을 갱신한다.
방향 전환 시 turn 값을 1 증가시키고, 최소 연료 소모량을 갱신한다.
//get answer
int minTurn = 200;
for (int i = 1; i <= n + m; i++)
if (dp[n - 1][m - 1][i][0] <= g || dp[n - 1][m - 1][i][1] <= g)
{
minTurn = min(minTurn, i);
break;
}
(minTurn == 200) ? cout << -1 : cout << minTurn + l * (n + m - 2);
cout << '\n';
답안 출력
방향 전환 횟수를 큰 값으로 초기화 한 뒤, 좌표에 G 이하의 연료 소모량으로 도달할 수 있다면, 그때의 방향 전환 횟수가 주행 시간을 결정한다.
#include <iostream>
#include <algorithm>
using namespace std;
#define IAMFAST ios_base::sync_with_stdio(false);cin.tie(0);
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<long long, long long> pll;
int t;
int n, m, l, g;
int dp[101][101][201][2];
pii cost[101][101];//{go right, go down}
void INPUT()
{
IAMFAST
cin >> t;
}
void caseInput()
{
cin >> n >> m >> l >> g;
for (int i = 0; i < n; i++)
for (int j = 1; j < m; j++)
cin >> cost[i][j].first;
for (int i = 1; i < n; i++)
for (int j = 0; j < m; j++)
cin >> cost[i][j].second;
}
void init()
{
for (int i = 0; i < 100; i++)
for (int j = 0; j < 100; j++)
for (int x = 0; x < 200; x++)
for (int y = 0; y < 2; y++)
dp[i][j][x][y] = 2e9;
dp[0][0][0][0] = dp[0][0][0][1] = 0;
}
void solution()
{
while (t--)
{
caseInput();
init();
//dp
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
for (int turn = 0; turn <= i + j; turn++)
{
if (i + 1 < n)
{
//right -> down
dp[i + 1][j][turn + 1][1] = min(dp[i + 1][j][turn + 1][1],
dp[i][j][turn][0] + cost[i + 1][j].second);
//down -> down
dp[i + 1][j][turn][1] = min(dp[i + 1][j][turn][1],
dp[i][j][turn][1] + cost[i + 1][j].second);
}
if (j + 1 < m)
{
//right -> right
dp[i][j + 1][turn][0] = min(dp[i][j + 1][turn][0],
dp[i][j][turn][0] + cost[i][j + 1].first);
//down -> right
dp[i][j + 1][turn + 1][0] = min(dp[i][j + 1][turn + 1][0],
dp[i][j][turn][1] + cost[i][j + 1].first);
}
}
//get answer
int minTurn = 200;
for (int i = 1; i < 200; i++)
if (dp[n - 1][m - 1][i][0] <= g || dp[n - 1][m - 1][i][1] <= g)
{
minTurn = min(minTurn, i);
break;
}
(minTurn == 200) ? cout << -1 : cout << minTurn + l * (n + m - 2);
cout << '\n';
}
}
int main()
{
INPUT();
solution();
}
세상에서 제일 난잡해지는 머리와 다르게 너무나 명료한 코드..