BOJ 13575 - 보석 가게 링크
(2023.10.16 기준 D5)
N개의 보석이 있으며 각각의 가치는 ai이며, 개수는 무한대이다.
K개의 보석을 가져갈 수 있을 때, 가능한 보석의 가치의 합을 모두 출력
FFT
BOJ 10531 - Golf Bot 풀이를 참고하자.
가능한 가치로 하여금 차수의 계수를 1로 잡고, 자기 자신과의 합성곱을 구하면 된다. K개를 가져갈 수 있으므로 A^K를 구한다고 생각하면 된다.하지만 이를 naive하게 K-1번 합성곱을 구하면 TLE이므로 분할 정복을 이용한 빠른 거듭제곱 방식을 똑같이 합성곱에 적용하면 된다.
그리고 최적화를 위해 거듭제곱을 할 때마다 각 계수를 최대 1로 다시 축소시켜주자. 그러면 TLE와 MLE를 피할 수 있을 것이다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const ll primitive_root = 3;
// 빠른 거듭제곱 (수)
ll fpow1(ll x, ll n, ll mod){
ll result = 1;
while (n){
if (n & 1) result = result * x % mod;
x = x * x % mod;
n >>= 1;
}
return result;
}
// 거듭제곱을 이용한 NTT
void ntt(vector<ll> &A, bool inv = false){
int n = A.size();
for (int i = 1, j = 0, bit; i < n; i++){
bit = n >> 1;
while (j >= bit) j -= bit, bit >>= 1;
j += bit;
if (i < j) swap(A[i], A[j]);
}
ll z, w, tmp;
for (int s = 2; s <= n; s <<= 1){
z = fpow1(primitive_root, (mod - 1) / s, mod);
if (inv) z = fpow1(z, mod - 2, mod);
for (int i = 0; i < n; i += s){
w = 1;
for (int j = i; j < i + (s >> 1); j++){
tmp = A[j + (s >> 1)] * w;
A[j + (s >> 1)] = (A[j] - tmp) % mod;
A[j] = (A[j] + tmp) % mod;
w = (w * z) % mod;
}
}
}
for (auto &x: A) if (x < 0) x += mod;
if (inv){
ll inv_n = fpow1(n, mod - 2, mod);
for (auto &x: A) x = x * inv_n % mod;
}
}
// a와 b의 합성곱을 a에 저장
void mul(vector<ll> &a, vector<ll> b){
ntt(a); ntt(b);
for (int i = 0, M = a.size(); i < M; i++) a[i] *= b[i];
ntt(a, true);
// MLE를 피하기 위해 각 자리의 결과를 1로 축소
for (int i = 0, M = a.size(); i < M; i++) if (a[i]) a[i] = 1;
}
// 빠른 거듭제곱 (다항식)
vector<ll> fpow2(vector<ll> &A, int K){
if (!K) return A;
vector<ll> result(A);
while (K){
if (K & 1) mul(result, A);
mul(A, A);
K >>= 1;
}
return result;
}
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
int N, K; cin >> N >> K;
int a[N]; for (int i = 0; i < N; i++) cin >> a[i];
int M = 1 << (int)ceil(log2(1001 * K)); // 합성곱을 구할 다항식의 길이 합(1001 * K)보다 큰, 가장 작은 2의 거듭제곱
vector<ll> A(M, 0);
for (auto ai: a) A[ai] = 1;
// K-1번의 자기 자신과의 합성곱
A = fpow2(A, K - 1);
// 0보다 크면 가능한 보석 가치의 합이다.
for (int i = 0; i < M; i++) if (A[i]) cout << i << ' ';
}
import sys; input = sys.stdin.readline
from math import ceil, log2
mod = 998244353
primitive_root = 3
# 거듭제곱을 이용한 NTT
def ntt(A, inv = False):
n = len(A)
j = 0
for i in range(1, n):
bit = n >> 1
while j >= bit:
j -= bit
bit >>= 1
j += bit
if i < j:
A[i], A[j] = A[j], A[i]
s = 2
while s <= n:
z = pow(primitive_root, (mod - 1) // s, mod)
if inv:
z = pow(z, mod - 2, mod)
for i in range(0, n, s):
w = 1
for j in range(i, i + (s >> 1)):
tmp = A[j + (s >> 1)] * w
A[j + (s >> 1)] = (A[j] - tmp) % mod
A[j] = (A[j] + tmp) % mod
w = (w * z) % mod
s <<= 1
for i in range(n):
if A[i] < 0:
A[i] += mod
if inv:
inv_n = pow(n, mod - 2, mod)
for i in range(n):
A[i] = A[i] * inv_n % mod
# a와 b의 합성곱을 a에 저장
def mul(a, b):
ntt(a); ntt(b)
for i in range(M):
a[i] *= b[i]
ntt(a, True)
# MLE를 피하기 위해 각 자리의 결과를 1로 축소
for i in range(M):
if a[i]:
a[i] = 1
# 빠른 거듭제곱
def fpow(A, K):
if not K:
return A
result = A.copy()
while K:
if K & 1:
mul(result, A.copy())
mul(A, A.copy())
K >>= 1
return result
N, K = map(int, input().split())
a = list(map(int, input().split()))
M = 1 << ceil(log2(1001 * K)) # 합성곱을 구할 다항식의 길이 합(1001 * K)보다 큰, 가장 작은 2의 거듭제곱
A = [0] * M
for ai in a:
A[ai] = 1
# K-1번의 자기 자신과의 합성곱
A = fpow(A, K - 1)
# 0보다 크면 가능한 보석 가치의 합이다.
for i in range(M):
if A[i]:
print(i, end = ' ')