[BOJ 22289] - 큰 수 곱셈 (3) (고속 푸리에 변환, 수학, C++, Python)

보양쿠·2023년 4월 19일
0

BOJ

목록 보기
107/260
post-custom-banner

BOJ 22289 - 큰 수 곱셈 (3) 링크
(2023.04.19 기준 P1)

문제

수의 길이가 최대 1,000,000인 음이 아닌 두 정수가 주어질 때, 두 정수의 곱 출력

알고리즘

그냥 곱셈은 O(N^2) 이므로 TLE. O(NlgN)인 FFT를 이용하여 곱셈을 해보자.

풀이

문제와 풀이는 큰 수 곱셈 (2)와 거의 같다.
아니, C++은 똑같은 코드를 내도 AC다.

그런데 Pypy3로는 TLE가 난다. 복소수 계산이 많아서 그런 듯 하다.
그러므로 NTT를 이용하여 풀면 된다. 덤으로 C++도 NTT를 이용하여 코드를 짜보았다.

코드

  • C++
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const ll primitive_root = 3;

// 빠른 거듭제곱
ll fpow(ll x, ll n, ll mod){
    ll result = 1;
    while (n){
        if (n & 1) result = result * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return result;
}

// 거듭제곱을 이용한 NTT
void ntt(vector<ll> &A, bool inv = false){
    int n = A.size();
    for (int i = 1, j = 0, bit; i < n; i++){
        bit = n >> 1;
        while (j >= bit) j -= bit, bit >>= 1;
        j += bit;
        if (i < j) swap(A[i], A[j]);
    }

    ll z, w, tmp;
    for (int s = 2; s <= n; s <<= 1){
        z = fpow(primitive_root, (mod - 1) / s, mod);
        if (inv) z = fpow(z, mod - 2, mod);
        for (int i = 0; i < n; i += s){
            w = 1;
            for (int j = i; j < i + (s >> 1); j++){
                tmp = A[j + (s >> 1)] * w;
                A[j + (s >> 1)] = (A[j] - tmp) % mod;
                A[j] = (A[j] + tmp) % mod;
                w = (w * z) % mod;
            }
        }
    }

    for (auto &x: A) if (x < 0) x += mod;

    if (inv){
        ll inv_n = fpow(n, mod - 2, mod);
        for (auto &x: A) x = x * inv_n % mod;
    }
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    string SA, SB;
    cin >> SA >> SB;

    // 필요한 길이(a + b - 1)보다 크면서, 가장 작은 2의 거듭제곱
    int a = SA.size(), b = SB.size(), N = 1 << (int)ceil(log2(a + b - 1));

    // 거듭제곱을 이용한 NTT를 위해 길이를 2의 거듭제곱이 되게끔 맞춘다.
    vector<ll> A(N, 0);
    for (int i = 0; i < a; i++) A[i] = SA[i] - '0';
    vector<ll> B(N, 0);
    for (int i = 0; i < b; i++) B[i] = SB[i] - '0';

    // A와 B의 합성곱 구하기
    ntt(A); ntt(B);
    for (int i = 0; i < N; i++) A[i] *= B[i];
    ntt(A, true);

    // a의 길이의 정수와 b의 길이의 정수의 곱은 최대 a + b다.
    ll result[a + b] = {0, };
    for (int i = 0; i < a + b - 1; i++) result[i + 1] = A[i];

    // 올림 처리
    for (int i = a + b - 1; i > 0; i--){
        result[i - 1] += result[i] / 10;
        result[i] %= 10;
    }

    if (result[0]) cout << result[0]; // 첫 자리는 0일 수도 있다.
    for (int i = 1; i < a + b; i++) cout << result[i];
}
  • Python (PyPy)
import sys; input = sys.stdin.readline
from math import ceil, log2

mod = 998244353
primitive_root = 3

# 거듭제곱을 이용한 NTT
def ntt(A, inv = False):
    n = len(A)
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            A[i], A[j] = A[j], A[i]

    s = 2
    while s <= n:
        z = pow(primitive_root, (mod - 1) // s, mod)
        if inv:
            z = pow(z, mod - 2, mod)
        for i in range(0, n, s):
            w = 1
            for j in range(i, i + (s >> 1)):
                tmp = A[j + (s >> 1)] * w
                A[j + (s >> 1)] = (A[j] - tmp) % mod
                A[j] = (A[j] + tmp) % mod
                w = (w * z) % mod
        s <<= 1

    for i in range(n):
        if A[i] < 0:
            A[i] += mod

    if inv:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            A[i] = A[i] * inv_n % mod

A, B = input().split()
A = list(map(int, A))
B = list(map(int, B))

a = len(A)
b = len(B)
N = 1 << ceil(log2(a + b - 1)) # 필요한 길이(a + b - 1)보다 크면서, 가장 작은 2의 거듭제곱

# 거듭제곱을 이용한 NTT를 위해 길이를 2의 거듭제곱이 되게끔 맞춘다.
A += [0] * (N - a)
B += [0] * (N - b)

# A와 B의 합성곱 구하기
ntt(A); ntt(B)
for i in range(N):
    A[i] *= B[i]
ntt(A, True)

# a의 길이의 정수와 b의 길이의 정수의 곱은 최대 a + b다.
result = [0] * (a + b)
for i in range(a + b - 1):
    result[i + 1] = A[i]

# 올림 처리
for i in range(a + b - 1, 0, -1):
    result[i - 1] += result[i] // 10
    result[i] %= 10

if result[0]: # 첫 자리는 0일 수도 있다.
    print(result[0], end = '')
print(*result[1:], sep = '')

여담

위는 복소수 FFT

위는 NTT

시간과 메모리가 차이가 나긴 한다.

profile
GNU 16 statistics & computer science
post-custom-banner

0개의 댓글