https://www.acmicpc.net/problem/11385
알고리즘
수학, 정수론, 고속 푸리에 변환, 중국인의 나머지 정리
문제 요약
단순한 다항식 곱셈 문제인데, 계수가 매우 크다는 특징이 있습니다.
1≤ai,bi≤1,000,000 로,
단순 FFT를 돌리면 실수 오차가 누적되어 터지고, 웬만한 소수로 NTT를 돌려도 터집니다.
FFT로도 실수 오차를 줄이는 방식으로 해결할 수 있지만,
이 글에선 다항식에 따라 NTT의 모듈러 소수의 정확한 조건 및,
CRT를 적용하는 방법에 대해서 자세히 다뤄보도록 하겠습니다.
FFT, NTT의 기본 개념에 대해선 알고 있다는 전제 하에 설명하겠습니다.
풀이
문제 조건에 맞는 소수를 어떻게 선택해야 될지 살펴봅시다.
편의상 계수 대신 길이라는 표현을 사용하겠습니다. (길이 = 계수 + 1)
1. 단일 소수 NTT
접근
NTT의 모듈러 소수는 항상 p=a⋅2b+1 형태를 만족해야 합니다.
보편적으로 사용하는 소수를 살펴보면,
p=998,244,353=119⋅223+1
여기서 중요한 건 b 의 값입니다.
b=23 이므로, 다항식의 길이가 223 이하일 때만 안전하다는 뜻입니다.
다행히 대부분의 문제는 해당 조건을 만족합니다.
22289번: 큰 수 곱셈 (3) 의 경우, 길이가 최대 1,000,000 인 두 다항식을 곱하는 문제이므로
곱셈의 결과는 길이가 최대 1,000,000⋅2−1=1,999,999 인 다항식이 되고,
이는 223=8,388,608 보다 훨씬 작습니다.
이 문제도 다항식의 길이는 각각 최대 1,000,001 로, 223 보다 충분히 작습니다.
그런데 해당 소수를 넣으면 터지는 이유가 뭘까요?
소수 조건
구체적으로,
다항식 f(x) 의 길이를 N, 계수의 최대 절댓값을 C 라고 할 때,
해당 다항식에 대해 DFT, IDFT를 수행할 수 있으려면
p≥N⋅C
를 만족해야 합니다. 이제 두 다항식을 곱하는 경우를 살펴봅시다.
다항식 f(x) 의 길이를 N, 계수의 최대 절댓값을 C1,
다항식 g(x) 의 길이를 M, 계수의 최대 절댓값을 C2 라고 할 때,
다항식 f(x)⋅g(x) 의 길이는 N+M−1 , 계수의 최대 절댓값은 C1⋅C2 이므로
소수 p=a⋅2b+1 에 대해
2b≥N+M−1
p≥(N+M−1)⋅C1⋅C2
를 만족해야 합니다. 두 다항식의 길이와 최대 절댓값이 N,C 로 동일할 경우
2b≥2N−1
p≥(2N−1)⋅C2
로 간단하게 정리할 수 있겠네요.
소수 선택
이제 문제의 조건을 살펴봅시다.
두 다항식의 최대 길이가 N=1,000,001, 계수의 최대 절댓값이 C=1,000,000 으로 동일하므로
2b≥2⋅1,000,001−1=2,000,001
p≥(2⋅1,000,001−1)⋅1,000,0002=2,000,001⋅1,000,000,000,000
즉,
b≥21
p≥2,000,001,000,000,000,000
을 만족해야 합니다.
이제 어떤 소수를 선택해야 할지, 왜 특정 소수를 넣으면 틀리는지 명확해졌습니다.
다시 22289번: 큰 수 곱셈 (3) 의 경우, 한 자리씩 끊어서 다항식을 구성할 경우
계수의 절댓값의 최댓값은 C=9 밖에 되지 않으므로,
p≥(2⋅1,000,000−1)⋅92=161,999,919 를 만족하기만 하면 됐었네요.
더 큰 단일 소수로 이 문제를 해결하는 방법을 생각해 봅시다.
p=998,242,353 은 너무 작습니다.
이러한 형태를 만족하는 소수를 찾아내는 방법은 여러 가지가 있는데,
이 부분에 대해선 나중에 설명하도록 하겠습니다.
p=4,603,910,272,195,756,033
이 소수를 살펴봅시다. 먼저 p≥2,000,001,000,000,000,000 조건을 만족하네요.
p−1 을 소인수분해하면
p−1=245⋅32⋅7⋅31⋅67
p=130,851⋅245+1
b=45 이므로 b≥21 조건도 만족합니다.
p 는 작을수록 좋습니다. 주어진 조건과 꽤 가까운 소수이므로, 사용하기 적합합니다.
해당 소수를 사용해 NTT를 돌리면 AC를 받게 됩니다.
2. NTT + CRT
그런데, 방금 말했듯이 소수가 커질수록 연산 속도가 느려집니다.
p 로 나눈 나머지를 취하기 때문에 계산 과정에서의 수가 최대 p−1 까지 커지기 때문입니다.
실제로 제출해보면 속도가 좀 많이 느린 것을 볼 수 있습니다.
CRT 적용
큰 단일 소수를 사용하는 대신, CRT를 사용해 여러 개의 작은 소수들을 사용해 결과를 얻는 방법을 살펴봅시다.
두 다항식 A(x),B(x) 을 곱하는 상황입니다. 각각의 길이는 N,M, 계수의 최대 절댓값은 C1,C2 입니다.
k 개의 소수 p1,p2,…,pk 을 사용한다고 가정해 봅시다.
상수 전처리
먼저 소수들을 전부 곱한 큰 모듈로 P 를 계산해줍시다.
P=i=1∏kpi
뒤에 설명하겠지만, pi 의 크기와 상관없이 P≥(N+M−1)⋅C1⋅C2 만 만족하면 됩니다.
부분 모듈로 P1,P2,…Pk 도 계산해줍니다. Pi 는 pi 을 제외한 모든 소수를 곱한 값입니다.
Pi=piP=p1⋅p2⋯pi−1⋅pi+1⋯pk (i=1,…,k)
마지막으로 모듈로 역원 m1,m2,…,mk 를 구해줍니다.
i=1,…,k 에 대해, Pi 에 대한 pi 모듈로의 곱셈 역원 mi 는
mi≡Pi−1modpi
즉, Pi⋅mi≡1modpi 를 만족하는 mi 를 구합니다.
이는 확장 유클리드 알고리즘을 통해 구할 수 있습니다. 구체적으로,
Pi⋅mi+pi⋅k=1 에서
gcd(Pi,pi)=1 이므로,
확장 유클리드 알고리즘을 통해 mi 와 k 를 구할 수 있습니다. k 는 중요하지 않습니다.
곱셈, 계수 복원
준비가 끝났습니다.
각각의 소수들로 NTT 곱셈을 수행합니다.
i=1,…,k 에 대해
Zi(x)≡A(x)⋅B(x)modpi
CRT로, 각 소수 p1,p2,…,pn 에 대해 모듈로 연산을 한 결과 Z1(x),Z2(x),…,Zk(x) 를 이용해
큰 모듈로 P 에 대한 곱셈 결과 Z(x) 를 복원합니다.
Z(x)≡i=1∑kZi(x)⋅Pi⋅mimodP
소수 조건
이제 부분 소수 p1,p2,…,pk 의 조건에 대해 살펴봅시다.
각 소수에 대해 다항식 곱셈을 수행해야 하므로
pi=ai⋅2bi+1 (i=1,…k) 라고 했을 때
2b1,2b2,…,2bk≥N+M−1, 즉
min(b1,b2,…,bk)≥log2(N+M−1)
를 만족해야 합니다. 부분 소수에 대한 개별적인 조건은 이게 끝입니다.
마지막으로, CRT를 통해 큰 모듈로 P 에 대한 계수를 복원할 수 있으므로,
P=i=1∏kpi≥(N+M−1)⋅C1⋅C2
를 만족하도록 소수들을 선택하면 됩니다.
소수 선택
문제 조건에 의해 min(b1,b2,…,bk)≥21 을 만족해야 합니다.
NTT 곱셈을 k 번 해야 하므로, 너무 많은 소수로 쪼개는 것도 비효율적입니다.
일반적으로 k=2 정도가 적당합니다. 두 개의 소수를 선택해 봅시다.
p1=1,300,234,241
p2=1,711,276,033
저는 이렇게 선택했습니다. 조건을 만족하는지 분석해 봅시다.
pi−1 의 소인수분해에 의해
p1=155⋅223+1
p2=51⋅225+1
min(23,25)=23≥21 을 만족하네요.
다음으로 P 를 계산해 봅시다.
P=p1⋅p2=1,300,234,241⋅1,711,276,033
P=2,225,059,693,909,245,953≥2,000,001,000,000,000,000
아주 적당히 만족합니다. 앞서 말했듯 소수는 작을수록 좋으므로 적절하게 잘 선택했네요.
곱셈 수행
앞에서 구한 공식에 따라 곱셈을 수행해 봅시다.
P1=p2=1,711,276,033
P2=p1=1,300,234,241
m1≡1,711,276,033−1mod1,300,234,241
m1=636,849,421
m2≡1,300,234,241−1mod1,711,276,033
m2=873,100,021
Z1(x)≡A(x)⋅B(x)mod1,300,234,241
Z2(x)≡A(x)⋅B(x)mod1,711,276,033
Z(x)≡Z1(x)⋅1,711,276,033⋅636,849,421+Z2(x)⋅1,300,234,241⋅873,100,021mod2,225,059,693,909,245,953
를 하면 AC를 받게 됩니다.
오버플로우는 적절히 타입 변환과 모듈러 분배법칙을 이용해 처리하면 됩니다.
모든 소수는 원시근을 갖고 있으니 찾는 방법에 대해선 따로 설명하지 않겠습니다.
코드 (요약)
#[inline(always)]
fn convolve_at(z: &mut [u64], x: &mut [u64], y: &mut [u64], p: u64, r: u64) -> () {
ntt(x, p, r, false);
ntt(y, p, r, false);
for i in 0..x.len() { z[i] = (x[i] * y[i]) % p; }
ntt(z, p, r, true);
}
const K: usize = 2;
const P: [u64; K] = [1_300_234_241, 1_711_276_033];
const R: [u64; K] = [3, 29];
const WP: u64 = ntt_crt_wp_u64__();
const XP: [u64; K] = ntt_crt_xp_u64__();
const MV: [u64; K] = ntt_crt_mv_u64__();
fn main() -> () {
let n: usize = pr!();
let m: usize = pr!();
let len: usize = n + m + 2;
let size: usize = len.next_power_of_two();
let mut x: Vec<u64> = vec![0; size];
let mut y: Vec<u64> = vec![0; size];
for i in 0..n + 1 { x[i] = pr!(); }
for i in 0..m + 1 { y[i] = pr!(); }
let mut z: Vec<u64> = vec![0; len];
let mut w: Vec<u64> = vec![0; size];
for i in 0..K {
if i < K {
convolve_at(&mut w, &mut x.clone(), &mut y.clone(), P[i], R[i])
} else {
convolve_at(&mut w, &mut x, &mut y, P[i], R[i])
};
for j in 0..len {
z[j] = safe_mod_add(z[j], safe_mod_prod(w[j], safe_mod_prod(XP[i], MV[i], WP), WP), WP);
}
}
let mut res: u64 = 0;
for &v in &z { res ^= v; }
wl!(res);
return;
}
위가 소수 2개 CRT, 아래가 단일 소수 NTT입니다.
여담
제 NTT 구현체가 상당히 느립니다.. 계속 개선해보고 있긴 한데 시간 줄이기가 참 어렵네요.
직접 구현해서 사용하시는 게 좋을 것 같습니다.
긴 글 읽어주셔서 감사합니다!