FFT를 통한 Convolution 수행

subin·2023년 9월 10일
0

😊one-more-thing

목록 보기
5/7
post-thumbnail

TODO

Solution

문제의 그림처럼 m x n 행렬에서 m x l 행렬을 슬라이딩 시키면서 그 위치에서 두 행렬의 pointwise multiplication sum을 구한다. 그 값이 w를 넘는 횟수를 구하는 문제이다. naive 하게 구하면 O(nml)로 시간초과가 난다.

즉, 이 문제는 FFT를 통해 Convolution을 수행함으로 빠르게 가능하다.

Convolution 연산은 하나의 수열을 가만히 놔두고 다른 수열이 뒤집어진 채로 슬라이딩 하면서 pointwise multiplication sum을 구하는 형태를 하고 있다.

이는 이 문제에서 원하는 것과 동일하다.

수열이 뒤집힌 채로 슬라이딩을 진행하기 때문에 행렬 P를 좌우로 뒤집은 행렬을 P'라고 하자.
그러면 T와 P'의 행 별로 Convolution을 취하면 각 행에서 Wk를 계산하는 데에 사용되는 pointwise multiplication sum을 구할 수 있다. 이에 걸리는 시간 복잡도는 O(m(n+l)log(n+l))이다.

그리고 Wk를 하나 구하는 데에는 각 행의 원소를 더해주는 데에 O(m)이 걸리고 Wk는 O(n)게 만큼 구해야되기 때문에 O(nm)이 걸린다.

이러한 방법으로 문제를 해결할 수 있다.
이를 python 코드로 표현해보자.

import cmath
import sys

input = sys.stdin.readline

def fft(a, inv=False):
    n = len(a)
    b = a.copy()
    
    for i in range(n):
        sz = n
        shift = 0
        idx = i
        while sz > 1:
            if idx & 1:
                shift += sz >> 1
            idx >>= 1
            sz >>= 1
        a[shift + idx] = b[i]

    i = 1
    while i < n:
        x = cmath.pi / i if inv else -cmath.pi / i
        w = cmath.cos(x) + cmath.sin(x) * 1j

        for j in range(0, n, i << 1):
            th = 1 + 0j
            for k in range(i):
                tmp = a[i + j + k] * th
                a[i + j + k] = a[j + k] - tmp
                a[j + k] += tmp
                th *= w
        i <<= 1

    if inv:
        for i in range(n):
            a[i] /= n


def convolution(a, b):
    N = 1
    while N < len(a) + len(b):
        N *= 2
    
    a += [complex(0)] * (N - len(a))
    b += [complex(0)] * (N - len(b))
    
    fft(a, False)
    fft(b, False)
    
    c = [a[i] * b[i] for i in range(N)]
    fft(c, True)
    for i, n in enumerate(c):
        c[i] = round(c[i].real)
    return c


N, L, M, W = map(int, input().split())
A = [[*map(int, input().split())] for _ in range(M)]
B = [[*map(int, input().split())] for _ in range(M)]

res = []
for i in range(M):
    conv = convolution(A[i], list(reversed(B[i])))
    res.append(conv[L-1:N])

col_sum = [sum(col) for col in zip(*res)]

print(len([n for n in col_sum if n > W]))
profile
한번뿐인 인생! 하고싶은게 너무 많은 뉴비의 deep-dive 현장

0개의 댓글