[Lucid] 합성곱 연산의 구현과 최적화

안암동컴맹·2025년 12월 11일

Lucid Development

목록 보기
7/20
post-thumbnail

🏞️ 합성곱 연산의 구현과 최적화

Lucid의 컨볼루션 경로를 설계할 때 가장 먼저 부딪힌 질문은 “범용성과 성능을 동시에 확보할 수 있는가”였다. 1D/2D/3D, stride/padding/dilation/groups 등 다양한 설정을 모두 지원하면서도, NumPy 기반 환경에서 합리적인 속도를 내야 했다. 이 글은 합성곱의 수학적 정의에서 출발해 naive 구현의 병목을 확인하고, Im2Col+GEMM 기반으로 재구성하는 과정에서 마주친 문제와 해결책을 정리한 기록이다.


🧭 이론적 배경 – 합성곱의 정의와 파라미터

2D 합성곱의 기본 형태는 다음과 같다(배치/채널 차원 생략):

y(i,j)=c=0Cin1u=0kh1v=0kw1xc(i ⁣+ ⁣u,j ⁣+ ⁣v)wc(u,v)y(i,j)=\sum_{c=0}^{C_\text{in}-1}\sum_{u=0}^{k_h-1}\sum_{v=0}^{k_w-1} x_c(i\!+\!u,j\!+\!v)\, w_c(u,v)

일반화하면 DD차원 공간에서

y[p,o]=c=0Cin1kKxc[p+k]wo,c[k]y[\mathbf{p}, o]=\sum_{c=0}^{C_\text{in}-1}\sum_{\mathbf{k}\in\mathcal{K}} x_c[\mathbf{p}+\mathbf{k}]\, w_{o,c}[\mathbf{k}]

여기서 K\mathcal{K}는 커널 영역의 모든 좌표다. 실전에서는 다음 파라미터들이 더해진다.

  • stride ss: 출력 좌표 증가량. p\mathbf{p} 대신 sps\mathbf{p}를 사용.
  • padding pp: 입력 주변을 0으로 확장.
  • dilation dd: 커널 간격을 dd배로 벌림(유효 커널 크기: d(k1)+1d(k-1)+1).
  • groups gg: 입력/출력을 gg개 그룹으로 나누어 독립적 합성곱(Depthwise는 g=Cin=Coutg=C_\text{in}=C_\text{out}).

출력 크기(1D 예시)는

O=I+2pd(k1)1s+1O = \Big\lfloor \frac{I + 2p - d(k-1) - 1}{s} \Big\rfloor + 1

이며, DD차원에서도 축별로 동일하게 계산한다.

🐢 Naive 구현과 한계

가장 직접적인 구현은 출력 위치마다 커널 영역을 슬라이스하고 곱-합을 수행하는 중첩 루프다(채널, 공간 차원 모두 루프). 이 방식은

  • 메모리 접근 불규칙: stride/padding/dilation에 따라 비연속 접근이 많아 캐시 효율이 낮다.
  • 벡터화 어려움: 작은 커널과 큰 입력을 다룰 때 SIMD 활용도가 떨어진다.
  • 중복 로드: 서로 다른 출력 위치가 입력의 같은 영역을 반복해서 읽는다.

계산 복잡도는 Im2Col과 동일한 O(NCoutCinkDout)O(N \cdot C_\text{out} \cdot C_\text{in} \cdot k^D \cdot |\text{out}|)이지만, 실제 실행 시간에서 큰 손해를 본다.

간단한 2D naive 의사코드(패딩/stride 반영)는 다음과 같다.

def conv2d_naive(x, w, b=None, stride=(1,1), padding=(0,0)):
    # x: (N, Cin, H, W), w: (Cout, Cin, Kh, Kw)
    N, Cin, H, W = x.shape
    Cout, Cin_w, Kh, Kw = w.shape
    assert Cin == Cin_w

    ph, pw = padding
    sh, sw = stride

    x_pad = pad(x, ((0,0),(0,0),(ph,ph),(pw,pw)))
    H_out = (H + 2*ph - Kh)//sh + 1
    W_out = (W + 2*pw - Kw)//sw + 1
    out = zeros((N, Cout, H_out, W_out))

    for n in range(N):
        for oc in range(Cout):
            for oh in range(H_out):
                for ow in range(W_out):
                    acc = 0
                    for ic in range(Cin):
                        for kh in range(Kh):
                            for kw in range(Kw):
                                ih = oh*sh + kh
                                iw = ow*sw + kw
                                acc += x_pad[n, ic, ih, iw] * w[oc, ic, kh, kw]
                    if b is not None:
                        acc += b[oc]
                    out[n, oc, oh, ow] = acc

    return out

실제 구현에서는 다차원 지원과 dilation, groups까지 더해지면 루프가 더 깊어져 메모리/캐시 효율이 떨어진다. 이 병목을 Im2Col로 풀어내는 것이 이번 글의 핵심 전환점이다.

🚀 Im2Col로의 전환 – 핵심 아이디어

Im2Col + GEMM: 모든 커널 슬라이스를 한 번에 메모리 연속 영역으로 펼친 뒤, 행렬 곱(GEMM)으로 처리한다.

  • 입력 패치 → 컬럼 행렬(col): shape (Nout,CinkD)(N_\text{out}, C_\text{in}\cdot k^D)
  • weight → 행렬: shape (Cout,CinkD)(C_\text{out}, C_\text{in}\cdot k^D)
  • 곱: (Nout×CinkD)(CinkD×Cout)=Nout×Cout(N_\text{out}\times C_\text{in}k^D) \cdot (C_\text{in}k^D \times C_\text{out}) = N_\text{out}\times C_\text{out}

장점:

  • 연속 메모리: 패치를 일렬로 모아 캐시/프리페치 유리.
  • BLAS 최적화: 고성능 GEMM 루틴 활용.
  • 일관된 형태: 1D/2D/3D 및 stride/padding/dilation/groups를 모두 동일한 변환 경로로 처리.

단점(및 감수해야 할 점): Nout×CinkDN_\text{out}\times C_\text{in}k^D 크기의 임시 버퍼 사용. Lucid는 reshape/transpose를 최소화하고 즉시 GEMM에 사용해 이 비용을 완화했다.

🧱 unfold: Im2Col 일반화

def unfold(input_, filter_size, stride, padding, dilation):
    # 출력 크기 계산 (stride/padding/dilation 반영)
    out_dims = [...]
    # 입력 패딩
    x = lucid.pad(input_, [(0,0), (0,0), *[(p,p) for p in padding]])

    # 모든 커널 오프셋 순회
    offsets = itertools.product(*[range(k) for k in filter_size])
    patches = []

    for off in offsets:
        sl = [slice(None), slice(None)]

        for d in range(D):
            start = off[d] * dilation[d]
            end = start + stride[d] * out_dims[d]
            sl.append(slice(start, end, stride[d]))

        p = x[tuple(sl)].unsqueeze(axis=2)
        patches.append(p)

    # (N_out, C*k^D) 평탄화
    col = lucid.concatenate(patches, axis=2)
    return col.reshape((N_out, C_filt))

파라미터 처리

  • filter_size/stride/padding/dilation 길이가 모두 DD인지 확인.
  • 출력 크기: Ii+2pidi(ki1)1si+1\left\lfloor \dfrac{I_i + 2p_i - d_i(k_i-1) - 1}{s_i} \right\rfloor + 1.
  • dilation은 오프셋 off[d] * dilation[d]로 반영.
  • shape 정렬: [N, C, *filter_size, *out_dims] → (N_out, C*k^D).

수학적 관점

패치 좌표 k=(k1,,kD)\mathbf{k}=(k_1,\dots,k_D), 출력 좌표 p=(p1,,pD)\mathbf{p}=(p_1,\dots,p_D), stride s\mathbf{s}, dilation d\mathbf{d}, padding ppad\mathbf{p}_{\text{pad}}에 대해 실제 입력 인덱스는

i=ps+kdppad\mathbf{i} = \mathbf{p}\odot\mathbf{s} + \mathbf{k}\odot\mathbf{d} - \mathbf{p}_{\text{pad}}

unfold는 모든 k\mathbf{k}를 열거해 (Nout,CinkD)(N_\text{out},\, C_\text{in}\cdot k^D) 행렬을 구성한다. 행 인덱스는 p\mathbf{p}를 1D로 나열한 것이고, 열 인덱스는 (c,k)(c,\mathbf{k})를 평탄화한 것이다.

인덱스 → 슬라이스 매핑

start = off[d] * dilation[d], end = start + stride[d] * out_dims[d], slice(start, end, stride[d])는 위 식에서 ps\mathbf{p}\odot\mathbf{s}를 슬라이스 스텝으로 구현한 것. 각 off 루프는 커널 좌표 k\mathbf{k}를 의미하고, 슬라이스 결과에 unsqueeze(axis=2)로 커널 축을 추가한 뒤 concat해 열 방향으로 쌓는다.

개발 중 만난 이슈

오프셋을 잘못 계산하면 stride/dilation 조합에서 음수 또는 0 크기 출력이 발생. 모든 축에 대해 유효성 검사를 추가했고, pad_config를 [batch, channel]에 대해 항상 (0,0)으로 두어 채널/배치 패딩 실수를 방지했다. 또한 슬라이스 스텝을 활용해 NumPy/MLX 내부 벡터화를 유도하도록 수정했다.

✖️ _im2col_conv: 핵심 곱셈

def _im2col_conv(input_, weight, bias, stride, padding, dilation, groups=1):
    # 채널/그룹 일관성 검사
    col = unfold(input_, filter_size, stride, padding, dilation)
    weight_rs = weight.reshape(groups, C_out_g, C_in_g * prod_filter)
    col_rs = col.reshape(N_out, groups, C_in_g * prod_filter)

    outs = []
    for g in range(groups):
        c_g = col_rs[:, g, :]   # (N_out, Cin_g*k^D)
        w_g = weight_rs[g]      # (Cout_g, Cin_g*k^D)

        out = lucid.einops.einsum("nk,ok->no", c_g, w_g)  # GEMM
        outs.append(out)

    out_cat = lucid.concatenate(outs, axis=1)
    out_nd = out_cat.reshape([N] + out_dims + [C_out]).transpose(perm)

    if bias is not None:
        out_nd += bias.reshape((1, C_out) + (1,) * D)

    return out_nd

파라미터 반영

  • groups: 입력/출력 채널을 groups로 나눠 독립 계산 후 concat. Depthwise는 groups=C_in=C_out, weight reshape로 자동 처리.
  • stride/padding/dilation: 모두 unfold 단계에서 반영(슬라이스 간격, pad, dilation 오프셋).
  • kernel_size: filter_size에서 prod_filter를 구해 weight/col reshape에 사용.
  • bias: 출력 텐서 shape에 맞춰 reshape 후 더한다.

성능 메모: group이 많을수록 작은 GEMM을 여러 번 실행한다. BLAS에 따라 작은 배치에서 오히려 비효율이 될 수 있으나, 캐시 친화성과 구현 단순성을 우선했다. 추후 Winograd/FFT 같은 특수 커널도 고려했으나, CPU/NumPy 환경에서는 Im2Col+GEMM이 가장 균형 잡힌 선택이었다.

🧭 1D/2D/3D 분기와 API 표준화

_im2col_conv는 공간 차원 DD에 대해 일반화되어 있다. 상위에서 conv1d/2d/3d를 제공할 때는:

  • kernel_size/stride/padding/dilation을 축 길이 튜플로 정규화.
  • 입력/가중치 차원 검증(최소 3D: N,C,Spatial).
  • groups/채널 배수 조건 확인.

모듈(nn.Conv1d/2d/3d)은 파라미터를 소유하고, forward는 이 functional 호출로 일관성을 유지한다. 이렇게 하면 backend 확장 시에도 functional만 교체하면 된다.

📊 복잡도와 메모리 트레이드오프

  • 시간: Im2Col 비용이 추가되나, GEMM 가속으로 전체 시간은 naive 대비 대체로 빠르다. 커널이 크거나 stride가 작은 경우 이득이 더 크다.
  • 메모리: 컬럼 버퍼가 필요하다. Lucid는 col을 바로 reshape해 group-wise GEMM에 사용해 불필요한 복사를 최소화했다.
  • groups/dilation: group 증가 시 작은 GEMM 다수 호출 → BLAS 오버헤드 가능. dilation 증가 시 패치 수 증가 → 메모리 접근 증가. 하지만 구조는 동일해 유지보수가 용이하다.

개발 당시 고민: 메모리-속도 트레이드오프를 어떻게 조절할지. Winograd/FFT 고려는 CPU 백엔드와 NumPy 환경에서 이점이 크지 않아 제외했다. 대신 코드 단순성과 범용성을 유지했다.

🧮 수학적 관점에서 본 Im2Col의 이점

행렬 곱 Y=XWY = XW^\top에서 XX(Nout,CinkD)(N_\text{out}, C_\text{in}k^D), WW(Cout,CinkD)(C_\text{out}, C_\text{in}k^D)이다. naive와 FLOPs는 동일하지만:

  • 연속성: Im2Col은 XX를 연속 메모리로 만들어 스트라이드가 커질수록 흩어지는 접근을 모은다.
  • 재사용성: 동일 입력 패치를 여러 출력 위치에서 재사용하는 대신 한 번의 변환으로 공유.
  • 벡터화: BLAS는 캐시 블로킹, SIMD, 스레딩이 이미 최적화되어 있어 별도 커널 튜닝이 필요 없다.

결과적으로 하드웨어 친화성이 올라가 실제 실행 시간이 줄어든다. 이는 특히 큰 커널(예: 5×5, 7×7)과 작은 stride에서 두드러진다.

😩 구현 여정에서의 난점과 해결

  1. 출력 크기 불일치: stride/dilation/padding 조합에서 음수 또는 0 출력이 나오는 버그. → 각 축마다 유효성 검사 추가, 에러 메시지로 축/값 명시.
  2. group/channel reshape 오류: weight reshape가 CinC_\text{in}, CoutC_\text{out} 배수 조건을 지키지 않을 때 silent miscompute. → 명시적 검사로 early fail.
  3. 메모리 사용량 우려: Im2Col 버퍼 크기. → col을 바로 reshape 후 einsum에 사용, 중간 사본 방지.
  4. 다차원 일반화: 1D/2D/3D 공통 코드에서 permute/reshape 순서를 잘못 두면 채널과 공간이 섞임. → permutation 리스트를 수식으로 정의하고, 테스트 케이스에서 permutation 결과 shape를 검증.

✅ 정리 및 다음 단계

Lucid의 합성곱은 다음 원칙을 따른다.

  • 형태 변환 → GEMM: 모든 공간/채널 파라미터를 unfold에서 처리한 뒤 행렬 곱으로 환원.
  • 파라미터 일반화: kernel_size/stride/padding/dilation/groups를 튜플로 정규화해 DD차원 공통 경로 유지.
  • 역전파 단순화: forward가 primitive와 reshape로 구성돼 autodiff가 그대로 역전파 처리. transpose conv도 동일 경로 재사용.

이 구조와 코드(위 링크)를 참고하면 임의 차원과 설정을 지원하는 Im2Col+GEMM 컨볼루션을 재현할 수 있다. 다음 단계에서는 모듈 레벨 API 구축을 위한 nn.Module의 구현 과정을 상세하게 기술할 예정이다.

profile
Korea Univ. Computer Science & Engineering

0개의 댓글