Problem link: https://www.acmicpc.net/problem/13392
Top-down DP 문제로, 역시나 내가 가장 취약한 부분답게 꽤 고전했다.
Top-down DP 풀이를 바로 기술하는 것은 취약점 개선에 아무런 도움이 되지 않으므로, 최종 답안까지 가는 아이디어를 하나하나 차근차근 기술해보도록 하겠다.
문제를 잘 관찰해보면 아래와 같은 사실들을 알 수 있다.
이 정도까지 파악하면 적어도 Brute-force 풀이는 아래 절과 같이 쉽게 찾아낼 수 있다.
관찰을 통해 이 문제는 위에 위치한 숫자 나사부터 차례로 아래로 내려가며 풀어주어야 한다는 사실을 알 수 있다.
따라서, 적어도 아래와 같은 단순한 풀이 정도는 떠올릴 수 있을 것이다.
def Solve(idx, Screws& screws): # idx번째 screw부터 끝까지 숫자를 맞출 때의 최소 회전 수를 반환
lmove = delta_l(screws, target, idx) # idx번째 screw를 좌회전을 통해 target(목표)까지 맞출 때의 이동량
lscrews = turn_left(screws, idx, lmove) # screw에 좌회전을 준다(idx 및 보다 밑에 위치하는 숫자 나사가 다 돌아간다)
l = lmove + Solve(idx + 1, lscrews) # 재귀
rmove = delta_r(screws, target, idx) # idx번째 screw를 우회전을 통해 target(목표)까지 맞출 때의 이동량
rscrews = turn_right(screws, idx, rmove) # screw에 우회전을 준다(idx에 위치하는 숫자 나사가 돌아간다)
r = rmove + Solve(idx + 1, rscrew) # 재귀
return min(l, r)
입력의 크기만 크지 않았다면 크게 나쁘지 않는 솔루션이지만, 우리의 입력 범위를 생각해보면 시간 내에 풀릴리가 없다.
그럼, 이제 이를 개선할 방법을 떠올려보자.
딱 봐도 중복되는 부분문제를 너무 많이 풀고 있지 않은가?
예를 들어, 위의 재귀 진행과정에서 idx=1, 2, 3
번째의 숫자 나사들을 각각 좌3, 우1, 우1
회전하였다고 가정해보자.
이때, idx=4
번째 나사는 초기 상태에서 3칸 왼쪽으로 돌아가 있는 상태일 것이다(보다 위에 있는 숫자 나사들의 좌회전 수 합이 3이므로).
이번에는, idx=1, 2, 3
번째의 숫자 나사들을 각각 좌1, 좌1, 좌1
회전하였다고 가정해보자.
이때에도 idx=4
번째 나사는 초기 상태에서 3칸 왼쪽으로 돌아가 있는 상태일 것이다.
상술한 두 가지의 경우는 idx=4
번째 나사의 위치가 동일하므로 동일한 답을 내는 동일한 부분문제일 수 밖에 없음에도 Brute-force 방법은 중복되는 문제를 여러번 풀게 된다.
그럼, 이제 위에서 말한 아이디어를 기초로 DP 풀이를 설계해보자.
Brute-force 풀이에서 얻을 수 있었던 교훈은 결국 idx
번째 숫자 나사를 고려할 때 중요한 것은 idx-1
번째 숫자 나사까지 좌회전 수의 합이 얼마냐는 사실이었다.
따라서, 아래와 같이 DP Cache를 정의한다(개폐구간을 눈여겨보자).
CACHE[idx][left_turn]
: [1, idx)
번째 숫자 나사까지는 이미 다 맞춰져 있고, 그때까지 사용한 좌회전 수의 합이 left_turn
개 일 때, [idx, N]
번째 숫자 나사를 맞추기 위한 최소 회전 횟수점화식은 사실 CACHE를 활용해서 메모이제이션 한다뿐이지 Brute-force와 동일하므로 생략한다.
좌회전 수의 합은 숫자 나사가 원형이므로 mod 10
연산이 가능하므로 실제 CACHE의 크기는 N*10
정도가 될 것이다.
보통 DP의 복잡도는 존재하는 부분분제의 수 x 각 부분 문제의 루프수
로 산정하니 이 문제 복잡도는 O(N)
이 될 것이다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int kMaxN = 10000;
int N;
int S[kMaxN + 1];
int T[kMaxN + 1];
int CACHE[kMaxN + 1][10];
int Solve(const int screw_idx, const int left_turn)
{
if (screw_idx > N)
{
return 0;
}
int& ret = CACHE[screw_idx][left_turn];
if (ret != -1)
{
return ret;
}
int cur = (S[screw_idx] + left_turn) % 10;
int lmove = (T[screw_idx] - cur < 0) ? T[screw_idx] - cur + 10 : T[screw_idx] - cur;
int rmove = (cur - T[screw_idx] < 0) ? cur - T[screw_idx] + 10 : cur - T[screw_idx];
ret = min(rmove + Solve(screw_idx + 1, left_turn), lmove + Solve(screw_idx + 1, (left_turn + lmove) % 10));
return ret;
}
int main(void)
{
// Initialize
for (int i = 0; i < kMaxN + 1; ++i)
{
for (int j = 0; j < 10; ++j)
{
CACHE[i][j] = -1;
}
}
// Read Inputs
scanf(" %d", &N);
for (int idx = 1; idx <= N; ++idx)
{
scanf(" %1d", &S[idx]);
}
for (int idx = 1; idx <= N; ++idx)
{
scanf(" %1d", &T[idx]);
}
// Solve
printf("%d\n", Solve(1, 0));
return 0;
}