Berlekamp-Massey Algorithm (BMA)
-
유한체 Fq 위에서 작동한다.
-
동차 선형 점화 수열 sn이 주어지면 해당 수열을 생성하는 k차 동차 선형 점화식을 반환하는 알고리즘이다.
-
시간 복잡도는 O(N2)이다.
-
주어진 수열을 순회하며 연산한다.
본 글에서 다루는 알고리즘은 원래 BMA와는 약간 다르다. DP에서 활용할 수 있는 형태의 점화식 계수 수열을 반환하는 것을 목표로 하기에 계수의 부호 등 일부는 원래 BMA의 그것과 정반대이고, 계수 다항식이 아니라 수열을 사용한다. 원래 BMA의 방식을 정확히 알고 싶다면 다른 문서를 참고하길 바란다.
각 단계 n(n=0,1,2,⋯)마다 n까지의 최적 선형 점화식 계수 수열 Cn={cin}과 그 길이인 Ln이 존재한다. c1n,c2n,⋯,cLnn이 최종적으로 구한 선형 점화식의 계수가 될 것이다.
c00=−1, 모든 n=0에 대해 cn0=0,L0=0으로 설정한다.
이를 이용해 sn을 예측한다. 예측값 s^n=i=1∑Lncinsn−i이다.
오차 dn을 계산한다. dn=sn−s^n이다.
만약 dn=0이라면 Cn이 sn을 완벽히 예측한 것이다. 점화식을 수정할 필요가 없으므로 Cn+1=Cn,Ln+1=Ln으로 설정하고 다음 n으로 넘어간다.
dn=0이라면 Cn을 수정해야 한다. 이를 위해 새로운 변수 m을 사용해야 한다.
m은 이전에 오류가 발생했던 인덱스, 즉 dm=0이며 n−m이 최소인 수이다.
m을 사용해 Cn+1을 새롭게 설정한다.
해당 식은 cin+1=cin−dmdnci−n+mm이다.
어떻게 새로 만든 Cn+1이 올바르다고 보장할 수 있을까?
Cn+1은 Ln+1≤N≤n인 모든 N에 대해 sN=i=1∑Ln+1cin+1sN−i를 만족해야 할 것이다.
i=1∑Ln+1cin+1sN−i=i=1∑Ln+1cinsN−i−dmdni=1∑Ln+1ci−n+mmsN−i=sN−dmdni=1∑Ln+1ci−n+mmsN−i
n<0,Lm<n에서 cnm=0이므로
i=1∑Ln+1ci−n+mmsN−i=i=n−m∑Ln+1ci−n+mmsN−i=i=n−m∑Lm+n−mci−n+mmsN−i=j=0∑LmcjmsN−j−n+m
그러므로 Ln+1≤N≤n−1인 N에 대해
j=0∑LmcjmsN−j−n+m=j=1∑LmcjmsN−j−n+m−sN−n+m=sN−n+m−sN−n+m=0이므로 sN=i=1∑Ln+1cin+1sN−i
N=n에 대해
i=1∑Ln+1cin+1sN−i=i=1∑Ln+1cinsN−i−dmdni=1∑Ln+1ci−n+mmsN−i=s^n−(sn−s^n)dm1i=1∑Ln+1ci−n+mmsN−i
i=1∑Ln+1ci−n+mmsN−i=i=n−m∑Ln+1ci−n+mmsN−i=i=n−m∑Lm+n−mci−n+mmsN−i=j=0∑Lmcjmsm−j=s^m−sm=−dm
∴i=1∑Ln+1cin+1sN−i=s^n+sn−s^n=sn
따라서 Ln+1≤N≤n인 모든 N에 대해
sN=i=1∑Ln+1cin+1sN−i이 성립하므로 Cn+1은 적절하다.
이때 2Ln≤ (수열의 항 수)를 만족해야 하는데, 그 이유는 Cn이 적합한지 확인하기 위해서는 Ln개의 검산 항이 필요하기 때문에 초기항 Ln개와 검산 항 Ln개, 총 2Ln개를 더해야 하기 때문이다.
Ln+1=n+1−Ln으로 나타낼 수 있는데, 이 또한 증명해 보자.
cin+1=cin−dmdnci−n+mm이므로
Ln+1=max{Ln,n−m+Lm}이다.
m의 초기값은 −1, Lm의 초기값은 0으로 설정했을 때 처음 오류가 발생한 n에 대해 Ln+1=max{0,n+1}=n+1=n+1−Ln
m에서 Lm+1=m+1−Lm이 성립할 때, Ln+1=max{Ln,n−m+Lm}=max{Ln,n+1−Lm+1}=max{Ln,n+1−Ln}
(n+1−Ln)−Ln=n+1−2Ln≥1>0
∴max{Ln,n+1−Ln}=n+1−Ln,Ln+1=n+1−Ln
수학적 귀납법에 의해 모든 n=0,1,⋯에 대해 Ln+1=n+1−Ln
이를 코드화시켜 보자.
모든 C,L,d를 저장하면 공간 복잡도가 지나치게 커지므로 필요한 Cm,Ln,dm과 m만 저장한다. Cm은 B, Ln은 L, dm은 w라고 하자.
입력은 수열 s와 유한체의 크기 mod면 충분하다.
def berl(S, mod = 998244353):
N = len(S)
M = -1
C = [0] * (N + 1)
C[0] = -1
B = [0] * (N + 1)
L = 0
d = 0
w = 1
for i in range(N):
d = S[i]
for j in range(1, L + 1):
d = (d - C[j] * S[i - j]) % mod
if d != 0:
tmp = C[:]
for j in range(i - M, N + 1):
C[j] = (C[j] - d * pow(w, -1, mod) * B[j - i + M]) % mod
if 2 * L <= i:
L = i + 1 - L
M = i
B = tmp
w = d
return C[1:L+1]
위의 코드를 필요에 따라 적절히 수정하여 사용하면 된다.