
Lucid의 컨볼루션 경로를 설계할 때 가장 먼저 부딪힌 질문은 “범용성과 성능을 동시에 확보할 수 있는가”였다. 1D/2D/3D, stride/padding/dilation/groups 등 다양한 설정을 모두 지원하면서도, NumPy 기반 환경에서 합리적인 속도를 내야 했다. 이 글은 합성곱의 수학적 정의에서 출발해 naive 구현의 병목을 확인하고, Im2Col+GEMM 기반으로 재구성하는 과정에서 마주친 문제와 해결책을 정리한 기록이다.
2D 합성곱의 기본 형태는 다음과 같다(배치/채널 차원 생략):
일반화하면 차원 공간에서
여기서 는 커널 영역의 모든 좌표다. 실전에서는 다음 파라미터들이 더해진다.
출력 크기(1D 예시)는
이며, 차원에서도 축별로 동일하게 계산한다.
가장 직접적인 구현은 출력 위치마다 커널 영역을 슬라이스하고 곱-합을 수행하는 중첩 루프다(채널, 공간 차원 모두 루프). 이 방식은
계산 복잡도는 Im2Col과 동일한 이지만, 실제 실행 시간에서 큰 손해를 본다.
간단한 2D naive 의사코드(패딩/stride 반영)는 다음과 같다.
def conv2d_naive(x, w, b=None, stride=(1,1), padding=(0,0)):
# x: (N, Cin, H, W), w: (Cout, Cin, Kh, Kw)
N, Cin, H, W = x.shape
Cout, Cin_w, Kh, Kw = w.shape
assert Cin == Cin_w
ph, pw = padding
sh, sw = stride
x_pad = pad(x, ((0,0),(0,0),(ph,ph),(pw,pw)))
H_out = (H + 2*ph - Kh)//sh + 1
W_out = (W + 2*pw - Kw)//sw + 1
out = zeros((N, Cout, H_out, W_out))
for n in range(N):
for oc in range(Cout):
for oh in range(H_out):
for ow in range(W_out):
acc = 0
for ic in range(Cin):
for kh in range(Kh):
for kw in range(Kw):
ih = oh*sh + kh
iw = ow*sw + kw
acc += x_pad[n, ic, ih, iw] * w[oc, ic, kh, kw]
if b is not None:
acc += b[oc]
out[n, oc, oh, ow] = acc
return out
실제 구현에서는 다차원 지원과 dilation, groups까지 더해지면 루프가 더 깊어져 메모리/캐시 효율이 떨어진다. 이 병목을 Im2Col로 풀어내는 것이 이번 글의 핵심 전환점이다.
Im2Col + GEMM: 모든 커널 슬라이스를 한 번에 메모리 연속 영역으로 펼친 뒤, 행렬 곱(GEMM)으로 처리한다.
col): shape 장점:
단점(및 감수해야 할 점): 크기의 임시 버퍼 사용. Lucid는 reshape/transpose를 최소화하고 즉시 GEMM에 사용해 이 비용을 완화했다.
unfold: Im2Col 일반화def unfold(input_, filter_size, stride, padding, dilation):
# 출력 크기 계산 (stride/padding/dilation 반영)
out_dims = [...]
# 입력 패딩
x = lucid.pad(input_, [(0,0), (0,0), *[(p,p) for p in padding]])
# 모든 커널 오프셋 순회
offsets = itertools.product(*[range(k) for k in filter_size])
patches = []
for off in offsets:
sl = [slice(None), slice(None)]
for d in range(D):
start = off[d] * dilation[d]
end = start + stride[d] * out_dims[d]
sl.append(slice(start, end, stride[d]))
p = x[tuple(sl)].unsqueeze(axis=2)
patches.append(p)
# (N_out, C*k^D) 평탄화
col = lucid.concatenate(patches, axis=2)
return col.reshape((N_out, C_filt))
filter_size/stride/padding/dilation 길이가 모두 인지 확인. off[d] * dilation[d]로 반영. [N, C, *filter_size, *out_dims] → (N_out, C*k^D).패치 좌표 , 출력 좌표 , stride , dilation , padding 에 대해 실제 입력 인덱스는
unfold는 모든 를 열거해 행렬을 구성한다. 행 인덱스는 를 1D로 나열한 것이고, 열 인덱스는 를 평탄화한 것이다.
start = off[d] * dilation[d], end = start + stride[d] * out_dims[d], slice(start, end, stride[d])는 위 식에서 를 슬라이스 스텝으로 구현한 것. 각 off 루프는 커널 좌표 를 의미하고, 슬라이스 결과에 unsqueeze(axis=2)로 커널 축을 추가한 뒤 concat해 열 방향으로 쌓는다.
오프셋을 잘못 계산하면 stride/dilation 조합에서 음수 또는 0 크기 출력이 발생. 모든 축에 대해 유효성 검사를 추가했고, pad_config를 [batch, channel]에 대해 항상 (0,0)으로 두어 채널/배치 패딩 실수를 방지했다. 또한 슬라이스 스텝을 활용해 NumPy/MLX 내부 벡터화를 유도하도록 수정했다.
_im2col_conv: 핵심 곱셈def _im2col_conv(input_, weight, bias, stride, padding, dilation, groups=1):
# 채널/그룹 일관성 검사
col = unfold(input_, filter_size, stride, padding, dilation)
weight_rs = weight.reshape(groups, C_out_g, C_in_g * prod_filter)
col_rs = col.reshape(N_out, groups, C_in_g * prod_filter)
outs = []
for g in range(groups):
c_g = col_rs[:, g, :] # (N_out, Cin_g*k^D)
w_g = weight_rs[g] # (Cout_g, Cin_g*k^D)
out = lucid.einops.einsum("nk,ok->no", c_g, w_g) # GEMM
outs.append(out)
out_cat = lucid.concatenate(outs, axis=1)
out_nd = out_cat.reshape([N] + out_dims + [C_out]).transpose(perm)
if bias is not None:
out_nd += bias.reshape((1, C_out) + (1,) * D)
return out_nd
파라미터 반영
groups: 입력/출력 채널을 groups로 나눠 독립 계산 후 concat. Depthwise는 groups=C_in=C_out, weight reshape로 자동 처리. stride/padding/dilation: 모두 unfold 단계에서 반영(슬라이스 간격, pad, dilation 오프셋). kernel_size: filter_size에서 prod_filter를 구해 weight/col reshape에 사용. bias: 출력 텐서 shape에 맞춰 reshape 후 더한다.성능 메모: group이 많을수록 작은 GEMM을 여러 번 실행한다. BLAS에 따라 작은 배치에서 오히려 비효율이 될 수 있으나, 캐시 친화성과 구현 단순성을 우선했다. 추후 Winograd/FFT 같은 특수 커널도 고려했으나, CPU/NumPy 환경에서는 Im2Col+GEMM이 가장 균형 잡힌 선택이었다.
_im2col_conv는 공간 차원 에 대해 일반화되어 있다. 상위에서 conv1d/2d/3d를 제공할 때는:
kernel_size/stride/padding/dilation을 축 길이 튜플로 정규화.모듈(nn.Conv1d/2d/3d)은 파라미터를 소유하고, forward는 이 functional 호출로 일관성을 유지한다. 이렇게 하면 backend 확장 시에도 functional만 교체하면 된다.
col을 바로 reshape해 group-wise GEMM에 사용해 불필요한 복사를 최소화했다. 개발 당시 고민: 메모리-속도 트레이드오프를 어떻게 조절할지. Winograd/FFT 고려는 CPU 백엔드와 NumPy 환경에서 이점이 크지 않아 제외했다. 대신 코드 단순성과 범용성을 유지했다.
행렬 곱 에서 는 , 는 이다. naive와 FLOPs는 동일하지만:
결과적으로 하드웨어 친화성이 올라가 실제 실행 시간이 줄어든다. 이는 특히 큰 커널(예: 5×5, 7×7)과 작은 stride에서 두드러진다.
col을 바로 reshape 후 einsum에 사용, 중간 사본 방지. Lucid의 합성곱은 다음 원칙을 따른다.
unfold에서 처리한 뒤 행렬 곱으로 환원.kernel_size/stride/padding/dilation/groups를 튜플로 정규화해 차원 공통 경로 유지.이 구조와 코드(위 링크)를 참고하면 임의 차원과 설정을 지원하는 Im2Col+GEMM 컨볼루션을 재현할 수 있다. 다음 단계에서는 모듈 레벨 API 구축을 위한 nn.Module의 구현 과정을 상세하게 기술할 예정이다.