문제 링크
https://www.acmicpc.net/problem/12850
일단 N이 10억씩이나 돼서 O(N)의 복잡도보다 더 좋은 것을 생각해내야만 했다.
문제를 분할해서 O(logN)의 복잡도를 가지는 알고리즘을 생각하려고 노력했다.
일단 거대한 N을 쪼개기 위해서 반으로 계속 나누어서 분할정복을 하기로 했고, 함수 f를 다음과 같이 정의했다.
f(d,frm,to)
= frm을 시작점, to를 끝점으로 하는, 거리가 d인 경로의 개수
그리하여 짠 코드는 다음과 같다.
def f(d, frm, to):
if d == 1:
return dist[frm][to] # 연결 돼있으면 1, 아니면 0
half = d // 2
other = half + 1 if d % 2 else half # 홀수면 +1
# half <= other
sum = 0
for k in range(N):
sum += f(half, frm, k) * f(other, k, to)
sum %= MOD
return sum
어차피 우리는 처음 위치로 돌아와야 한다. 그래서 처음 함수를 호출할 때 f(D,0,0)
의 형태로 호출을 하게된다.
우리는 이 D를 절반의 경로 2개로 쪼갤 것이다. (홀수면 다른 하나가 1 크죠)
하지만 궁극적인 우리의 목표, 시작점과 끝점이 둘 다 0이긴 해야할 것이다. 그러면 쪼개진 경계점을 어떻게 설정해야할지 감이 잡힐 것이다.
sum = 0
for k in range(N):
sum += f(half, frm, k) * f(other, k, to)
sum %= MOD
이런식으로 점 k를 경유해서 간다고 생각하면 된다.
D == half + other
이고
half
의 경로를 갖는, 0에서 시작하고 k에서 끝나는 경로들과
other
의 경로를 갖는, k에서 시작하고 0에서 끝나는 경로들의 곱들의 "합"
그러니까 k=0,1,2,..,7 일때의 모든 경우의 수를 합한 것이
f(D,0,0)
의 값이 되는 것이다.
이러한 형태로 재귀적으로 쪼개고 쪼개다보면 탈출조건에 도달하여, 연결이 되어있는지까지 가는 것이다.
그렇게 해서 D -> D//2 -> ... -> 2 -> 1 까지 내려가면서 최종적으로 f(D,0,0)
의 값을 재귀적으로 구하는 것이다.
...
하지만 이 경우 함수의 중복 호출이 너무나도 많아진다.
따라서 메모이제이션을 생각해보았다.
이번엔 메모이제이션을 생각해보았다.
m[d][frm][to]
= 함수 f와 같이 frm을 시작점, to를 끝점으로 하는, 거리가 d인 경로의 개수
처음엔 일반적인 메모이제이션을 하려고 배열에 선언을 하고 값을 집어넣으려고 했는데...
N이 너무 커서 배열을 초기화하는 데에도 엄청난 시간이 소요되었다.
그래서 딕셔너리를 이용하여, 사용되는 숫자들에 대해서만 메모이제이션을 하는 방법을 생각했다.
m={}
m[1] = [
[0, 1, 0, 0, 0, 0, 0, 1],
[1, 0, 1, 0, 0, 0, 0, 1],
[0, 1, 0, 1, 0, 0, 1, 1],
[0, 0, 1, 0, 1, 0, 1, 0],
[0, 0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0, 1, 0],
[0, 0, 1, 1, 0, 1, 0, 1],
[1, 1, 1, 0, 0, 0, 1, 0],
]
def f(d, frm, to):
if d <= 1:
return m[d][frm][to]
m.setdefault(d, [[0 for _ in range(N)] for _ in range(N)])
if m[d][frm][to]:
return m[d][frm][to]
# ...(이하 생략)...
이런식으로 함수가 호출되었을 때 만약 m[d]
가 없으면 m[d]
에 8*8 이차원 배열을 기본 값으로 할당한다.
미리 구해놓은 m[d][frm][to]
가 있으면 그것을 반환하여 중복 호출을 막는다.
메모이제이션을 적용하여 코드를 재구성하면 정답 코드이다.
MOD = 1000000007
N = 8
m = {}
D = int(input())
m[1] = [
[0, 1, 0, 0, 0, 0, 0, 1],
[1, 0, 1, 0, 0, 0, 0, 1],
[0, 1, 0, 1, 0, 0, 1, 1],
[0, 0, 1, 0, 1, 0, 1, 0],
[0, 0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0, 1, 0],
[0, 0, 1, 1, 0, 1, 0, 1],
[1, 1, 1, 0, 0, 0, 1, 0],
]
def f(d, frm, to):
if d <= 1:
return m[d][frm][to]
m.setdefault(d, [[0 for _ in range(N)] for _ in range(N)])
if m[d][frm][to]:
return m[d][frm][to]
half = d // 2
other = half + 1 if d % 2 else half # 홀수면 +1
# half <= other
for k in range(N):
m[d][frm][to] += f(half, frm, k) * f(other, k, to)
m[d][frm][to] %= MOD
return m[d][frm][to]
print(f(D, 0, 0))