PS를 위한 Berlekamp-Massey Algorithm과 그 Python 구현

루트삼·2025년 5월 5일

알고리즘

목록 보기
1/1
post-thumbnail

Berlekamp-Massey Algorithm (BMA)

  • 유한체 Fq\mathbb F_q 위에서 작동한다.

  • 동차 선형 점화 수열 sns_n이 주어지면 해당 수열을 생성하는 kk차 동차 선형 점화식을 반환하는 알고리즘이다.

  • 시간 복잡도는 O(N2)\mathcal O (N^2)이다.

  • 주어진 수열을 순회하며 연산한다.

본 글에서 다루는 알고리즘은 원래 BMA와는 약간 다르다. DP에서 활용할 수 있는 형태의 점화식 계수 수열을 반환하는 것을 목표로 하기에 계수의 부호 등 일부는 원래 BMA의 그것과 정반대이고, 계수 다항식이 아니라 수열을 사용한다. 원래 BMA의 방식을 정확히 알고 싶다면 다른 문서를 참고하길 바란다.

각 단계 n  (n=0,1,2,)n\; (n=0,1,2,\cdots)마다 nn까지의 최적 선형 점화식 계수 수열 Cn={cin}C_n=\{c_i^{n}\}과 그 길이인 LnL_n이 존재한다. c1n,c2n,,cLnnc_1^n, c_2^n,\cdots,c_{L_n}^n이 최종적으로 구한 선형 점화식의 계수가 될 것이다.

c00=1,c_0^0=-1, 모든 n0n\neq 0에 대해 cn0=0,  L0=0c_n^0=0, \; L_0 = 0으로 설정한다.

이를 이용해 sns_n을 예측한다. 예측값 s^n=i=1Lncinsni\hat s_n=\displaystyle\sum_{i=1}^{L_n}c_i^ns_{n-i}이다.

오차 dnd_n을 계산한다. dn=sns^nd_n = s_n - \hat s_n이다.

만약 dn=0d_n = 0이라면 CnC_nsns_n을 완벽히 예측한 것이다. 점화식을 수정할 필요가 없으므로 Cn+1=Cn,Ln+1=LnC_{n+1} = C_n, L_{n+1} = L_n으로 설정하고 다음 nn으로 넘어간다.

dn0d_n \neq 0이라면 CnC_n을 수정해야 한다. 이를 위해 새로운 변수 mm을 사용해야 한다.

mm은 이전에 오류가 발생했던 인덱스, 즉 dm0d_m \neq 0이며 nmn - m이 최소인 수이다.

mm을 사용해 Cn+1C_{n+1}을 새롭게 설정한다.

해당 식은 cin+1=cindndmcin+mmc_i^{n+1} = c_i^n - \displaystyle\frac{d_n}{d_m}c^m_{i-n+m}이다.

어떻게 새로 만든 Cn+1C_{n+1}이 올바르다고 보장할 수 있을까?
Cn+1C_{n+1}Ln+1NnL_{n+1} \leq N \leq n인 모든 NN에 대해 sN=i=1Ln+1cin+1sNis_N = \displaystyle\sum^{L_{n+1}}_{i=1}{c_i^{n+1}s_{N-i}}를 만족해야 할 것이다.

i=1Ln+1cin+1sNi=i=1Ln+1cinsNidndmi=1Ln+1cin+mmsNi=sNdndmi=1Ln+1cin+mmsNi\displaystyle\sum^{L_{n+1}}_{i=1}{c^{n+1}_is_{N-i}}=\displaystyle\sum^{L_{n+1}}_{i=1}{c^n_is_{N-i}}-\displaystyle\frac{d_n}{d_m}\displaystyle\sum^{L_{n+1}}_{i=1}c^m_{i-n+m}s_{N-i}=s_N-\displaystyle\frac{d_n}{d_m}\displaystyle\sum^{L_{n+1}}_{i=1}c^m_{i-n+m}s_{N-i}

n<0,Lm<nn<0,L_m< n에서 cnm=0c^m_n=0이므로

i=1Ln+1cin+mmsNi=i=nmLn+1cin+mmsNi=i=nmLm+nmcin+mmsNi=j=0LmcjmsNjn+m\displaystyle\sum^{L_{n+1}}_{i=1}c^m_{i-n+m}s_{N-i}=\displaystyle\sum^{L_{n+1}}_{i=n-m}c^m_{i-n+m}s_{N-i}=\displaystyle\sum^{L_m+n-m}_{i=n-m}{c^m_{i-n+m}s_{N-i}}=\displaystyle\sum^{L_m}_{j=0}{c^m_js_{N-j-n+m}}

그러므로 Ln+1Nn1L_{n+1} \leq N \leq n - 1NN에 대해

j=0LmcjmsNjn+m=j=1LmcjmsNjn+msNn+m=sNn+msNn+m=0\displaystyle\sum^{L_m}_{j=0}{c^m_js_{N-j-n+m}}=\displaystyle\sum^{L_m}_{j=1}{c^m_js_{N-j-n+m}}-s_{N-n+m}=s_{N-n+m}-s_{N-n+m}=0이므로 sN=i=1Ln+1cin+1sNis_N = \displaystyle\sum^{L_{n+1}}_{i=1}{c_i^{n+1}s_{N-i}}

N=nN = n에 대해

i=1Ln+1cin+1sNi=i=1Ln+1cinsNidndmi=1Ln+1cin+mmsNi=s^n(sns^n)1dmi=1Ln+1cin+mmsNi\displaystyle\sum^{L_{n+1}}_{i=1}{c^{n+1}_is_{N-i}}=\displaystyle\sum^{L_{n+1}}_{i=1}{c^n_is_{N-i}}-\displaystyle\frac{d_n}{d_m}\displaystyle\sum^{L_{n+1}}_{i=1}c^m_{i-n+m}s_{N-i}=\hat s_n - (s_n - \hat s_n ) \displaystyle \frac{1}{d_m}\displaystyle\sum^{L_{n+1}}_{i=1}c^m_{i-n+m}s_{N-i}

i=1Ln+1cin+mmsNi=i=nmLn+1cin+mmsNi=i=nmLm+nmcin+mmsNi=j=0Lmcjmsmj=s^msm=dm\displaystyle\sum^{L_{n+1}}_{i=1}c^m_{i-n+m}s_{N-i}=\displaystyle\sum^{L_{n+1}}_{i=n-m}c^m_{i-n+m}s_{N-i}=\displaystyle\sum^{L_m+n-m}_{i=n-m}{c^m_{i-n+m}s_{N-i}}=\displaystyle\sum^{L_m}_{j=0}{c^m_js_{m-j}}=\hat s_m - s_m = -d_m

i=1Ln+1cin+1sNi=s^n+sns^n=sn\therefore \displaystyle\sum^{L_{n+1}}_{i=1}{c^{n+1}_is_{N-i}}=\hat s_n + s_n - \hat s_n = s_n

따라서 Ln+1NnL_{n+1} \leq N \leq n인 모든 NN에 대해

sN=i=1Ln+1cin+1sNis_N = \displaystyle\sum^{L_{n+1}}_{i=1}{c_i^{n+1}s_{N-i}}이 성립하므로 Cn+1C_{n+1}은 적절하다.

이때 2Ln2L_n\leq (수열의 항 수)를 만족해야 하는데, 그 이유는 CnC_n이 적합한지 확인하기 위해서는 LnL_n개의 검산 항이 필요하기 때문에 초기항 LnL_n개와 검산 항 LnL_n개, 총 2Ln2L_n개를 더해야 하기 때문이다.

Ln+1=n+1LnL_{n+1}=n + 1 - L_n으로 나타낼 수 있는데, 이 또한 증명해 보자.

cin+1=cindndmcin+mmc_i^{n+1}=c_i^n - \displaystyle\frac{d_n}{d_m}c^m_{i-n+m}이므로

Ln+1=max{Ln,nm+Lm}L_{n+1}=\max\{L_n, n-m+L_m\}이다.

mm의 초기값은 1-1, LmL_m의 초기값은 00으로 설정했을 때 처음 오류가 발생한 nn에 대해 Ln+1=max{0,n+1}=n+1=n+1LnL_{n+1}=\max\{0, n+1\}=n+1=n+1-L_n

mm에서 Lm+1=m+1LmL_{m+1}=m+1-L_m이 성립할 때, Ln+1=max{Ln,nm+Lm}=max{Ln,n+1Lm+1}=max{Ln,n+1Ln}L_{n+1}=\max\{L_n,n-m+L_m\}=\max\{L_n,n+1-L_{m+1}\}=\max\{L_n,n+1-L_n\}

(n+1Ln)Ln=n+12Ln1>0(n+1-L_n)-L_n=n+1-2L_n\geq1>0
max{Ln,n+1Ln}=n+1Ln,  Ln+1=n+1Ln\therefore \max\{L_n, n+1-L_n\}=n+1-L_n,\;L_{n+1}=n+1-L_n

수학적 귀납법에 의해 모든 n=0,1,n=0,1,\cdots에 대해 Ln+1=n+1LnL_{n+1}=n+1-L_n

이를 코드화시켜 보자.

모든 C,L,dC, L, d를 저장하면 공간 복잡도가 지나치게 커지므로 필요한 Cm,Ln,dmC_m, L_n, d_mmm만 저장한다. CmC_mB,B, LnL_nL,L, dmd_mww라고 하자.

입력은 수열 ss와 유한체의 크기 mod\mathrm{mod}면 충분하다.

def berl(S, mod = 998244353):
    N = len(S)
    M = -1
    C = [0] * (N + 1)
    C[0] = -1
    B = [0] * (N + 1)
    L = 0
    d = 0
    w = 1
    for i in range(N):
        d = S[i]
        for j in range(1, L + 1):
            d = (d - C[j] * S[i - j]) % mod
        if d != 0:
            tmp = C[:]
            for j in range(i - M, N + 1):
                C[j] = (C[j] - d * pow(w, -1, mod) * B[j - i + M]) % mod
            if 2 * L <= i:
                L = i + 1 - L
                M = i
                B = tmp
                w = d
    return C[1:L+1]

위의 코드를 필요에 따라 적절히 수정하여 사용하면 된다.

profile
안녕하세요.

0개의 댓글