[BOJ 13575] - 보석 가게 (고속 푸리에 변환, 수학, C++, Python)

보양쿠·2023년 10월 16일
0

BOJ

목록 보기
205/260
post-custom-banner

BOJ 13575 - 보석 가게 링크
(2023.10.16 기준 D5)

문제

N개의 보석이 있으며 각각의 가치는 ai이며, 개수는 무한대이다.
K개의 보석을 가져갈 수 있을 때, 가능한 보석의 가치의 합을 모두 출력

알고리즘

FFT

풀이

BOJ 10531 - Golf Bot 풀이를 참고하자.
가능한 가치로 하여금 차수의 계수를 1로 잡고, 자기 자신과의 합성곱을 구하면 된다. K개를 가져갈 수 있으므로 A^K를 구한다고 생각하면 된다.

하지만 이를 naive하게 K-1번 합성곱을 구하면 TLE이므로 분할 정복을 이용한 빠른 거듭제곱 방식을 똑같이 합성곱에 적용하면 된다.

그리고 최적화를 위해 거듭제곱을 할 때마다 각 계수를 최대 1로 다시 축소시켜주자. 그러면 TLE와 MLE를 피할 수 있을 것이다.

코드

  • C++
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const ll primitive_root = 3;

// 빠른 거듭제곱 (수)
ll fpow1(ll x, ll n, ll mod){
    ll result = 1;
    while (n){
        if (n & 1) result = result * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return result;
}

// 거듭제곱을 이용한 NTT
void ntt(vector<ll> &A, bool inv = false){
    int n = A.size();
    for (int i = 1, j = 0, bit; i < n; i++){
        bit = n >> 1;
        while (j >= bit) j -= bit, bit >>= 1;
        j += bit;
        if (i < j) swap(A[i], A[j]);
    }

    ll z, w, tmp;
    for (int s = 2; s <= n; s <<= 1){
        z = fpow1(primitive_root, (mod - 1) / s, mod);
        if (inv) z = fpow1(z, mod - 2, mod);
        for (int i = 0; i < n; i += s){
            w = 1;
            for (int j = i; j < i + (s >> 1); j++){
                tmp = A[j + (s >> 1)] * w;
                A[j + (s >> 1)] = (A[j] - tmp) % mod;
                A[j] = (A[j] + tmp) % mod;
                w = (w * z) % mod;
            }
        }
    }

    for (auto &x: A) if (x < 0) x += mod;

    if (inv){
        ll inv_n = fpow1(n, mod - 2, mod);
        for (auto &x: A) x = x * inv_n % mod;
    }
}

// a와 b의 합성곱을 a에 저장
void mul(vector<ll> &a, vector<ll> b){
    ntt(a); ntt(b);
    for (int i = 0, M = a.size(); i < M; i++) a[i] *= b[i];
    ntt(a, true);

    // MLE를 피하기 위해 각 자리의 결과를 1로 축소
    for (int i = 0, M = a.size(); i < M; i++) if (a[i]) a[i] = 1;
}

// 빠른 거듭제곱 (다항식)
vector<ll> fpow2(vector<ll> &A, int K){
    if (!K) return A;
    vector<ll> result(A);
    while (K){
        if (K & 1) mul(result, A);
        mul(A, A);
        K >>= 1;
    }
    return result;
}


int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int N, K; cin >> N >> K;
    int a[N]; for (int i = 0; i < N; i++) cin >> a[i];

    int M = 1 << (int)ceil(log2(1001 * K)); // 합성곱을 구할 다항식의 길이 합(1001 * K)보다 큰, 가장 작은 2의 거듭제곱
    vector<ll> A(M, 0);
    for (auto ai: a) A[ai] = 1;

    // K-1번의 자기 자신과의 합성곱
    A = fpow2(A, K - 1);

    // 0보다 크면 가능한 보석 가치의 합이다.
    for (int i = 0; i < M; i++) if (A[i]) cout << i << ' ';
}
  • Python (PyPy3)
import sys; input = sys.stdin.readline
from math import ceil, log2

mod = 998244353
primitive_root = 3

# 거듭제곱을 이용한 NTT
def ntt(A, inv = False):
    n = len(A)
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            A[i], A[j] = A[j], A[i]

    s = 2
    while s <= n:
        z = pow(primitive_root, (mod - 1) // s, mod)
        if inv:
            z = pow(z, mod - 2, mod)
        for i in range(0, n, s):
            w = 1
            for j in range(i, i + (s >> 1)):
                tmp = A[j + (s >> 1)] * w
                A[j + (s >> 1)] = (A[j] - tmp) % mod
                A[j] = (A[j] + tmp) % mod
                w = (w * z) % mod
        s <<= 1

    for i in range(n):
        if A[i] < 0:
            A[i] += mod

    if inv:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            A[i] = A[i] * inv_n % mod

# a와 b의 합성곱을 a에 저장
def mul(a, b):
    ntt(a); ntt(b)
    for i in range(M):
        a[i] *= b[i]
    ntt(a, True)

    # MLE를 피하기 위해 각 자리의 결과를 1로 축소
    for i in range(M):
        if a[i]:
            a[i] = 1

# 빠른 거듭제곱
def fpow(A, K):
    if not K:
        return A
    result = A.copy()
    while K:
        if K & 1:
            mul(result, A.copy())
        mul(A, A.copy())
        K >>= 1
    return result

N, K = map(int, input().split())
a = list(map(int, input().split()))

M = 1 << ceil(log2(1001 * K)) # 합성곱을 구할 다항식의 길이 합(1001 * K)보다 큰, 가장 작은 2의 거듭제곱
A = [0] * M
for ai in a:
    A[ai] = 1

# K-1번의 자기 자신과의 합성곱
A = fpow(A, K - 1)

# 0보다 크면 가능한 보석 가치의 합이다.
for i in range(M):
    if A[i]:
        print(i, end = ' ')
profile
GNU 16 statistics & computer science
post-custom-banner

0개의 댓글