Fast Fourier Transform

youngyangze · January 20, 2024


Fast Fourier Transform (FFT)

The Fast Fourier Transform (FFT) is an algorithm used to efficiently compute the Discrete Fourier Transform (DFT) of a sequence of values. It is widely used in various fields such as signal processing, image processing, and data compression.

Concept

The FFT computes the same result as the DFT, which transforms a time-domain signal into its frequency-domain representation. By doing so, we can analyze the signal's frequency components and extract useful information from it. Evaluating the DFT directly takes O(N^2) operations for N samples; the FFT reduces this to O(N log N), which makes it much faster for large inputs.
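A sketch of where the O(N log N) comes from, assuming N is a power of 2 (the radix-2 case that the code in this post uses). Write the DFT of x_0, ..., x_{N-1} as X_k = \sum_n x_n e^{-i 2\pi kn/N} and split the sum into even- and odd-indexed samples. E_k and O_k are names introduced here for the length-N/2 DFTs of those two halves, and T(N) for the cost of a length-N transform.

```latex
\begin{aligned}
X_k &= \sum_{m=0}^{N/2-1} x_{2m}\, e^{-i \cdot 2\pi \cdot \frac{k m}{N/2}}
     + e^{-i \cdot 2\pi \cdot \frac{k}{N}} \sum_{m=0}^{N/2-1} x_{2m+1}\, e^{-i \cdot 2\pi \cdot \frac{k m}{N/2}} \\
    &= E_k + e^{-i \cdot 2\pi \cdot \frac{k}{N}}\, O_k, \\
X_{k+N/2} &= E_k - e^{-i \cdot 2\pi \cdot \frac{k}{N}}\, O_k, \qquad 0 \le k < N/2, \\
T(N) &= 2\,T(N/2) + O(N) \;\Rightarrow\; T(N) = O(N \log N).
\end{aligned}
```

The X_{k+N/2} line uses the fact that E_k and O_k repeat with period N/2 while the twiddle factor e^{-i 2\pi k/N} changes sign when k grows by N/2. One length-N transform therefore costs two length-N/2 transforms plus O(N) work, giving log2 N levels of O(N) each. For N = 2^20 that is roughly N^2 ≈ 1.1 × 10^12 versus N log2 N ≈ 2.1 × 10^7 basic operations, ignoring constant factors.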

Usage

To use the FFT algorithm, follow these steps:

  1. Preprocess the input signal by padding it with zeros up to the next power of 2. The radix-2 FFT used in this post requires the input length to be a power of 2 (other FFT variants relax this requirement).
  2. Apply the FFT to the padded signal, using your own implementation or a library function.
  3. Obtain the frequency-domain representation of the signal: a sequence of complex numbers whose magnitude and phase give the amplitude and phase of each frequency component.
  4. Perform further analysis or processing on the frequency-domain representation as needed.
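The steps above can be sketched in C++ as follows. This is an illustration under assumptions, not the author's code: next_pow2, fft_recursive and spectrum are hypothetical helper names, and fft_recursive is a compact recursive radix-2 FFT written for readability, separate from the iterative implementation in the Code section. Step 4 is left to the caller, for example std::abs (magnitude) and std::arg (phase) on each output value.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cd = std::complex<double>;

// Step 1 helper (hypothetical name): smallest power of 2 that is >= n.
std::size_t next_pow2(std::size_t n) {
    std::size_t p = 1;
    while (p < n) p *= 2;
    return p;
}

// Step 2 (hypothetical name): recursive radix-2 forward FFT.
// Splits into even/odd samples and combines with twiddle factors.
std::vector<cd> fft_recursive(const std::vector<cd>& x) {
    const std::size_t n = x.size();
    assert(n > 0 && (n & (n - 1)) == 0);  // length must be a power of 2
    if (n == 1) return x;

    std::vector<cd> even(n / 2), odd(n / 2);
    for (std::size_t m = 0; m < n / 2; m++) {
        even[m] = x[2 * m];
        odd[m] = x[2 * m + 1];
    }
    std::vector<cd> E = fft_recursive(even);
    std::vector<cd> O = fft_recursive(odd);

    const double pi = std::acos(-1.0);
    std::vector<cd> X(n);
    for (std::size_t k = 0; k < n / 2; k++) {
        // Twiddle factor e^{-i 2 pi k / n} applied to the odd half.
        cd t = std::polar(1.0, -2.0 * pi * k / n) * O[k];
        X[k] = E[k] + t;
        X[k + n / 2] = E[k] - t;
    }
    return X;
}

// Steps 1-3 (hypothetical name): zero-pad, transform, return the spectrum.
std::vector<cd> spectrum(const std::vector<double>& signal) {
    std::vector<cd> x(signal.begin(), signal.end());
    x.resize(next_pow2(x.size()));  // zero padding up to a power of 2
    return fft_recursive(x);
}
```

For example, spectrum({1.0, 2.0, 3.0}) pads the input to {1, 2, 3, 0} and returns 6, -2-2i, 2 and -2+2i (up to rounding), so std::abs gives magnitudes 6, about 2.83, 2 and about 2.83.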

Equation

The FFT computes the Discrete Fourier Transform, which is defined as follows:

X_k = \sum_{n=0}^{N-1} x_n \cdot e^{-i \cdot 2\pi \cdot \frac{k \cdot n}{N}}

Each X_k, for k = 0, 1, ..., N-1, is a complex number.

Where:

  • X_k is the complex-valued frequency component at index k,
  • x_n is the input signal at index n,
  • N is the length of the input signal,
  • k ranges from 0 to N-1.

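This equation sums the contribution of every time-domain sample x_n to each frequency component X_k. Evaluated directly, it needs N multiply-adds for each of the N outputs, which is where the O(N^2) cost of the plain DFT comes from.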
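For comparison, the definition can be evaluated directly. The following is a minimal O(N^2) sketch; naive_dft is a hypothetical name, and unlike the radix-2 FFT it accepts any length N.

```cpp
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

// Direct evaluation of X_k = sum_{n=0}^{N-1} x_n * e^{-i 2 pi k n / N}.
// O(N^2) time; no power-of-2 requirement on N.
std::vector<std::complex<double>> naive_dft(const std::vector<std::complex<double>>& x) {
    const std::size_t N = x.size();
    const double pi = std::acos(-1.0);
    std::vector<std::complex<double>> X(N);
    for (std::size_t k = 0; k < N; k++) {
        for (std::size_t n = 0; n < N; n++) {
            // e^{-i 2 pi kn/N} is periodic in kn with period N, so reduce kn mod N
            // to keep the angle small and the result more accurate.
            double angle = -2.0 * pi * static_cast<double>((k * n) % N) / static_cast<double>(N);
            X[k] += x[n] * std::polar(1.0, angle);
        }
    }
    return X;
}
```

On small inputs this makes a convenient reference when testing an FFT: an impulse {1, 0, 0, 0} transforms to all ones, and a constant {1, 1, 1, 1} to {4, 0, 0, 0}.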

Code

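Below is an iterative radix-2 (Cooley-Tukey) implementation of the FFT in C++, together with a function that uses it to multiply two polynomials with integer coefficients.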

#include <cmath>
#include <complex>
#include <utility>
#include <vector>
using namespace std;

typedef complex<double> cpx;
namespace lab {
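    // In-place iterative radix-2 FFT. vec.size() must be a power of 2.
    // inv == false: forward transform (kernel e^{-i 2 pi kn/N}).
    // inv == true:  inverse transform, including the final division by N.
    void FFT(vector<cpx>& vec, bool inv) {
        int len = vec.size();

        // Reorder the input into bit-reversed index order so the
        // butterflies below can run bottom-up and in place.
        for (int i = 1, j = 0; i < len; i++) {
            int bit = len / 2;

            while (j >= bit) {
                j -= bit;
                bit /= 2;
            }
            j += bit;  // j is now the bit reversal of i

            if (i < j) {
                swap(vec[i], vec[j]);
            }
        }

        // Merge DFTs of length k into DFTs of length 2k.
        for (int k = 1; k < len; k *= 2) {
            // w is a primitive (2k)-th root of unity:
            // e^{-i pi/k} for the forward transform, e^{+i pi/k} for the inverse.
            double angle = (inv ? acos(-1) / k : -acos(-1) / k);
            cpx w(cos(angle), sin(angle));

            for (int i = 0; i < len; i += k * 2) {
                cpx z(1, 0);  // current twiddle factor w^j

                for (int j = 0; j < k; j++) {
                    cpx even = vec[i + j];
                    cpx odd = vec[i + j + k];

                    // Butterfly: combine the even-index and odd-index halves.
                    vec[i + j] = even + z * odd;
                    vec[i + j + k] = even - z * odd;

                    z *= w;
                }
            }
        }

        // The inverse DFT carries a 1/N factor.
        if (inv) {
            for (int i = 0; i < len; i++) {
                vec[i] /= len;
            }
        }
    }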
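    // Multiplies two polynomials given as integer coefficient vectors
    // (lowest degree first) and returns vec.size() + vec2.size() - 1
    // coefficients. Coefficients are recovered by rounding doubles, so very
    // long inputs or very large coefficients can lose precision.
    vector<int> mul(const vector<int>& vec, const vector<int>& vec2) {
        if (vec.empty() || vec2.empty())
            return {};

        size_t res_len = vec.size() + vec2.size() - 1;

        vector<cpx> vc(vec.begin(), vec.end());
        vector<cpx> uc(vec2.begin(), vec2.end());

        // Smallest power of 2 that can hold the whole product, so the cyclic
        // convolution computed by the FFT equals the ordinary product.
        size_t n = 1;
        while (n < res_len)
            n *= 2;

        vc.resize(n);
        FFT(vc, false);
        uc.resize(n);
        FFT(uc, false);

        // Pointwise product of the transforms = convolution of the coefficients.
        for (size_t i = 0; i < n; i++)
            vc[i] *= uc[i];
        FFT(vc, true);

        // Round away floating-point error and drop the zero padding.
        vector<int> w(res_len);
        for (size_t i = 0; i < res_len; i++)
            w[i] = (int)lround(vc[i].real());

        return w;
    }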
} // namespace lab
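To check what mul should return, a plain O(n·m) polynomial multiplication is a useful reference. This is a sketch for testing only; naive_mul is a hypothetical name, and it accumulates in long long so that the reference itself does not overflow on moderately sized inputs.

```cpp
#include <cstddef>
#include <vector>

// Reference polynomial multiplication (hypothetical helper, for testing):
// c[k] = sum over all i + j == k of a[i] * b[j].
// This is the convolution an FFT-based multiply computes in O((n+m) log(n+m)).
std::vector<long long> naive_mul(const std::vector<int>& a, const std::vector<int>& b) {
    if (a.empty() || b.empty()) return {};
    std::vector<long long> c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); i++)
        for (std::size_t j = 0; j < b.size(); j++)
            c[i + j] += static_cast<long long>(a[i]) * b[j];
    return c;
}
```

For example, (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2, so naive_mul({1, 2}, {3, 4}) returns {3, 10, 8}. Comparing it with the first a.size() + b.size() - 1 entries of lab::mul's result on small random inputs is one way to catch implementation mistakes or double-precision rounding errors.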

