시작하며,
틈틈이 SCPC(삼성전자 대학생 프로그래밍 경진대회)를 준비하고 있는데요. 개강하기도 했고 졸업 연구도 병행하고 있어서 방학 때 만큼의 시간 투자가 어렵다는 게 근래 가장 큰 근심입니다. 이번 포스팅에선 2022년도 SCPC 2차 예선에 출제되었던 5번 문제 "황금 카드"에 필요한 fft에 대해 다뤄보려고 합니다. 개인적으로 난이도 극악이었지만 fast Fourier transform을 활용한 고속 다항 곱셈에 대한 지식을 확장할 수 있는 뜻깊은 경험이었습니다. 다시 한번 divide and conquer의 강력함을 느꼈고 fft의 개념을 너무 알고 싶어서 시험공부도 다 뒤로하고 강박적인 며칠을 보냈습니다. 이 글을 끝으로 FFT에 대해 조금 더 숙지하고 떠내보내는 것이 소소한 바람입니다.
참고 자료:
https://speakerdeck.com/wookayin/fast-fourier-transform-algorithm
https://casterian.net/algo/fft.html
https://namnamseo.tistory.com/entry/FFT-in-competitive-programming
https://topology-blog.tistory.com/6
1. DFT, IDFT, polynomial multiplication
Polynomial A(x),B(x)가 주어져 있다고 합시다.
A(x)=a0+a1x+a2x2+...+aN−1xN−1
B(x)=b0+b1x+b2x2+...+bN−1xN−1
(N은 (maxdegreeA∗B)−1 이상 입니다.)
주의!
Cooley-Tukey Algorithm을 활용하기 때문에 N은 2의 거듭제곱수입니다.
A(x)는 다항식, ai는 다항식 A(x)의 i번째 차수의 계수, Ai는 퓨리에 변환된 값들을 나타냅니다.
이때, <a0,a1,...,aN−1>, <b0,b1,...,bN−1>을 각각 A, B의 coefficient vector라고 표현합니다.
다항식의 곱 C=A∗B의 coefficient vector를 naive하게 계산하기 위해서 convolution 연산을 필요로 합니다.
ck=∑i+j=n,0≤i,j≤kaibj
하지만 위의 연산은 k=0,1,...,N−1까지 고려해야하고 ck 하나를 연산하는데 최대 O(N)번의 연산이 필요하므로 위의 연산은 O(N2)의 시간복잡도를 가집니다.
convolution 연산을 Fourier domain에서 곱셈 연산이 되므로 이 성질을 활용해 coefficient vectors를 fourier transform하고 element-wise 곱셈을 마친 뒤 inverse fourier transform을 해서 다항식의 곱셈 값을 구할 수 있습니다. 물론, discrete fourier transform(DFT)을 활용합니다.
DFT: Aj=A(wNj)=∑k=0N−1ak(wNj)k,wN=eN2πi
coefficient vector를 discrete fourier transform 시키면 아래의 변환이 성립합니다.
<a0,a1,...,aN−1> → <<A0,A1,...,AN−1>>
<b0,b1,...,bN−1> → <<B0,B1,...,BN−1>>
한편, C=A∗B의 discrete fourier transform을 생각해봅니다.
Cj=C(wNj)=A(wNj)∗B(wNj)=AjBj
따라서 아래의 식이 성립합니다.
<<C0,C1,...,CN−1>>=<<A0∗B0,A1∗B1,...,AN−1∗BN−1>>
위의 식을 inverse discrete fourier transform 하면 다항식의 곱을 구할 수 있습니다.
<<C0,C1,...,CN−1>> → <c0,c1,...,cN−1>
inverse discrete fourier transform(IDFT)를 유도하기 위해선 DFT의 식에 wˉNlj을 곱해줍니다. 그 뒤 j=0,1,2,...,N−1에 대해서 summation을 취해줍니다.
Aj∗wˉNlj=A(wNj)∗wˉNlj=∑k=0N−1ak(wNj)k∗wˉNlj
∑j=0N−1A(wNj)∗wˉNlj=∑j=0N−1∑k=0N−1ak(wNj)k∗wˉNlj=∑k=0N−1ak∗∑j=0N−1(wNj)k−l=Nal
따라서 IDFT는 아래와 같은 식으로 정리할 수 있습니다.
IDFT: aj=N1∑k=0N−1Ak∗wˉNjk
C=A∗B을 구하는 법을 정리하자면 아래와 같습니다.
1) <a0,a1,...,aN−1> → <<A0,A1,...,AN−1>> (fourier transform)
2) <b0,b1,...,bN−1> → <<B0,B1,...,BN−1>> (fourier transform)
3 )<<C0,C1,...,CN−1>>=<<A0∗B0,A1∗B1,...,AN−1∗BN−1>> (element-wise multiplication)
4) <<C0,C1,...,CN−1>> → <c0,c1,...,cN−1> (inverse fourier transform)
2. FFT
Fast Fourier transform은 discrete fourier transform를 divide and conquer 방식을 활용해서 O(Nlog(N))의 시간복잡도로 디자인했습니다.
Polynomial A(x)에 대해 생각합니다. N은 반드시 2의 거듭제곱 꼴의 자연수입니다.
A(x)=a0+a1x+a2x2+...+aN−1xN−1
다항식 A에 대해서 DFT를 구하기 위해서 wNk, k=0,1,2,...,N−1에 대한 A(wNk) 값을 전부 구해야 됩니다.
한편, A(x)를 even coefficient term, odd coefficient term으로 분할 할 수 있습니다.
A(x)=a0+a1x+a2x2+...+aN−1xN−1=(a0+a2x2+a4x4+...)+(a1x+a3x3+a5x5+...)=(a0+a2x2+a4x4+...)+x∗(a1+a3x2+a5x4+...)=Aeven(x2)+x∗Aodd(x2)
Aeven(x)=a0+a2x+a4x2+...+aN−2xN/2−1
Aodd(x)=a1+a3x+a5x2+...+aN−1xN/2−1
위의 수식을 참고해서 computational thinking을 시도할 수 있습니다.
FFT(A) := polynomial의 coefficient vector A를 parameter로 전달 받아 k=0,1,2,...,N−1에 대해 A(wNk)을 계산하여 vector로 반환합니다.
FFT(Aeven) = polynomial의 coefficient vector Aeven 를 parameter로 전달 받아 k=0,1,2,...,N/2−1에 대해 A(w2Nk)을 계산하여 vector로 반환합니다.
FFT(Aodd) = polynomial의 coefficient vector Aodd 를 parameter로 전달 받아 k=0,1,2,...,N/2−1에 대해 A(w2Nk)을 계산하여 vector로 반환합니다.
NOTE!!
w2Nk=(wNk)2
따라서, 아래의 알고리즘이 성립합니다.
vector<int> FFT(A) {
N = A.size();
if (N == 1)
return A;
A_even, A_odd = split(A);
allocate vectors: fft_A(N), fft_A_even(N/2), fft_A_odd(N/2);
w = exp(2 * pi * i / N);
for (int i = 0; i < N / 2; i++) {
fft_A[i] = fft_A_even[i] + pow(w, i) * fft_A_odd[i];
fft_A[i + N / 2] = fft_A_even[i + N / 2] - pow(w, i) * fft_A_odd[i + N / 2];
}
return fft_A;
}
DFT와 IDFT를 비교해봅시다.
DFT: Aj=∑k=0N−1ak(wNj)k
IDFT: aj=N1∑k=0N−1Ak∗wˉNjk
IDFT는 DFT의 방식과 동일하지만 sampling을 wNj 대신 켤레인 wˉNj으로 취해주면 되고 N으로 나눠주면 됩니다.
4. code review (FFT)
typedef complex<double> base;
void fft(vector<base> &a, vector<base> &A) {
int n = (int) a.size();
if (n == 1) {
A[0] = a[0];
return;
}
vector<base> even(n / 2), odd(n / 2), Even(n / 2), Odd(n / 2);
for (int i = 0; i < n / 2; i++) {
even[i] = a[2 * i];
odd[i] = a[2 * i + 1];
}
fft(even, Even);
fft(odd, Odd);
double th = 2.0 * M_PI / n;
base w = base(cos(th), sin(th));
base z = base(1);
for (int i = 0; i < n / 2; i++) {
A[i] = Even[i] + z * Odd[i];
A[i + n / 2] = Even[i] - z * Odd[i];
z *= w;
}
}
void ifft(vector<base> &A, vector<base> &a) {
reverse(++A.begin(), A.end());
fft(A, a);
int n = (int) a.size();
for (int i = 0; i < n; i++) {
a[i] /= n;
}
}
마치며,
다음엔 비재귀적인 fft 구현까지 다뤄보도록하겠습니다. 오랜만에 신선한 개념을 주입하니 brain teasing 되는 느낌이 아주 좋네요. 한번에 모든 것을 cover up 하기 보다 천천히 그치만 정확하게 짚고 넘어가는 방식을 택하겠습니다!