BOJ 22289 - 큰 수 곱셈 (3) 링크
(2023.04.19 기준 P1)
수의 길이가 최대 1,000,000인 음이 아닌 두 정수가 주어질 때, 두 정수의 곱 출력
그냥 곱셈은 O(N^2) 이므로 TLE. O(NlgN)인 FFT를 이용하여 곱셈을 해보자.
문제와 풀이는 큰 수 곱셈 (2)와 거의 같다.
아니, C++은 똑같은 코드를 내도 AC다.그런데 Pypy3로는 TLE가 난다. 복소수 계산이 많아서 그런 듯 하다.
그러므로 NTT를 이용하여 풀면 된다. 덤으로 C++도 NTT를 이용하여 코드를 짜보았다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const ll primitive_root = 3;
// 빠른 거듭제곱
ll fpow(ll x, ll n, ll mod){
ll result = 1;
while (n){
if (n & 1) result = result * x % mod;
x = x * x % mod;
n >>= 1;
}
return result;
}
// 거듭제곱을 이용한 NTT
void ntt(vector<ll> &A, bool inv = false){
int n = A.size();
for (int i = 1, j = 0, bit; i < n; i++){
bit = n >> 1;
while (j >= bit) j -= bit, bit >>= 1;
j += bit;
if (i < j) swap(A[i], A[j]);
}
ll z, w, tmp;
for (int s = 2; s <= n; s <<= 1){
z = fpow(primitive_root, (mod - 1) / s, mod);
if (inv) z = fpow(z, mod - 2, mod);
for (int i = 0; i < n; i += s){
w = 1;
for (int j = i; j < i + (s >> 1); j++){
tmp = A[j + (s >> 1)] * w;
A[j + (s >> 1)] = (A[j] - tmp) % mod;
A[j] = (A[j] + tmp) % mod;
w = (w * z) % mod;
}
}
}
for (auto &x: A) if (x < 0) x += mod;
if (inv){
ll inv_n = fpow(n, mod - 2, mod);
for (auto &x: A) x = x * inv_n % mod;
}
}
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
string SA, SB;
cin >> SA >> SB;
// 필요한 길이(a + b - 1)보다 크면서, 가장 작은 2의 거듭제곱
int a = SA.size(), b = SB.size(), N = 1 << (int)ceil(log2(a + b - 1));
// 거듭제곱을 이용한 NTT를 위해 길이를 2의 거듭제곱이 되게끔 맞춘다.
vector<ll> A(N, 0);
for (int i = 0; i < a; i++) A[i] = SA[i] - '0';
vector<ll> B(N, 0);
for (int i = 0; i < b; i++) B[i] = SB[i] - '0';
// A와 B의 합성곱 구하기
ntt(A); ntt(B);
for (int i = 0; i < N; i++) A[i] *= B[i];
ntt(A, true);
// a의 길이의 정수와 b의 길이의 정수의 곱은 최대 a + b다.
ll result[a + b] = {0, };
for (int i = 0; i < a + b - 1; i++) result[i + 1] = A[i];
// 올림 처리
for (int i = a + b - 1; i > 0; i--){
result[i - 1] += result[i] / 10;
result[i] %= 10;
}
if (result[0]) cout << result[0]; // 첫 자리는 0일 수도 있다.
for (int i = 1; i < a + b; i++) cout << result[i];
}
import sys; input = sys.stdin.readline
from math import ceil, log2
mod = 998244353
primitive_root = 3
# 거듭제곱을 이용한 NTT
def ntt(A, inv = False):
n = len(A)
j = 0
for i in range(1, n):
bit = n >> 1
while j >= bit:
j -= bit
bit >>= 1
j += bit
if i < j:
A[i], A[j] = A[j], A[i]
s = 2
while s <= n:
z = pow(primitive_root, (mod - 1) // s, mod)
if inv:
z = pow(z, mod - 2, mod)
for i in range(0, n, s):
w = 1
for j in range(i, i + (s >> 1)):
tmp = A[j + (s >> 1)] * w
A[j + (s >> 1)] = (A[j] - tmp) % mod
A[j] = (A[j] + tmp) % mod
w = (w * z) % mod
s <<= 1
for i in range(n):
if A[i] < 0:
A[i] += mod
if inv:
inv_n = pow(n, mod - 2, mod)
for i in range(n):
A[i] = A[i] * inv_n % mod
A, B = input().split()
A = list(map(int, A))
B = list(map(int, B))
a = len(A)
b = len(B)
N = 1 << ceil(log2(a + b - 1)) # 필요한 길이(a + b - 1)보다 크면서, 가장 작은 2의 거듭제곱
# 거듭제곱을 이용한 NTT를 위해 길이를 2의 거듭제곱이 되게끔 맞춘다.
A += [0] * (N - a)
B += [0] * (N - b)
# A와 B의 합성곱 구하기
ntt(A); ntt(B)
for i in range(N):
A[i] *= B[i]
ntt(A, True)
# a의 길이의 정수와 b의 길이의 정수의 곱은 최대 a + b다.
result = [0] * (a + b)
for i in range(a + b - 1):
result[i + 1] = A[i]
# 올림 처리
for i in range(a + b - 1, 0, -1):
result[i - 1] += result[i] // 10
result[i] %= 10
if result[0]: # 첫 자리는 0일 수도 있다.
print(result[0], end = '')
print(*result[1:], sep = '')
위는 복소수 FFT
위는 NTT
시간과 메모리가 차이가 나긴 한다.