본 글은 1D CNN에서 수식을 기반으로 한 backpropagation 과정을 간단한 예제를 포함합니다.

1. 간단한 1D CNN 모델

구성

  • 입력: x=[x1,x2,x3]x = [x_1, x_2, x_3] (1채널, 길이 3)
  • 커널: w=[w1,w2]w = [w_1, w_2], bias: bb
  • 출력: z=w1x1+w2x2+bz = w_1 x_1 + w_2 x_2 + b
  • 활성화 함수: ReLU
  • 손실 함수: MSE (Mean Squared Error)
  • 타겟: yy

2. 순전파 (Forward pass)

Conv1D + Bias

입력:
x=[x1,x2]x = [x_1, x_2],
가중치: w=[w1,w2]w = [w_1, w_2],
바이어스: bb

Convolution 결과 (합성곱):

z=w1x1+w2x2+bz = w_1 x_1 + w_2 x_2 + b

ReLU

a=ReLU(z)=max(0,z)a = \text{ReLU}(z) = \max(0, z)

출력층 (간단히 선형):

y^=a\hat{y} = a

3. 손실 함수 (MSE)

L=12(y^y)2\mathcal{L} = \frac{1}{2} (\hat{y} - y)^2

4. 역전파 (Backpropagation)

Step 1: 손실 w.r.t 출력

Ly^=y^y\frac{\partial \mathcal{L}}{\partial \hat{y}} = \hat{y} - y

Step 2: 출력 w.r.t ReLU

y^a=1Ly^y^a=La=y^y\frac{\partial \hat{y}}{\partial a} = 1 \quad \Rightarrow \quad \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a} = \frac{\partial \mathcal{L}}{\partial a} = \hat{y} - y

Step 3: ReLU w.r.t 합성곱 결과 zz

az={1if z>00otherwiseLz=(y^y)1z>0\frac{\partial a}{\partial z} = \begin{cases} 1 & \text{if } z > 0 \\ 0 & \text{otherwise} \end{cases} \quad \Rightarrow \quad \frac{\partial \mathcal{L}}{\partial z} = (\hat{y} - y) \cdot \mathbf{1}_{z > 0}

Step 4: 합성곱 결과 w.r.t 가중치, 입력, 바이어스

  • zw1=x1\frac{\partial z}{\partial w_1} = x_1
  • zw2=x2\frac{\partial z}{\partial w_2} = x_2
  • zb=1\frac{\partial z}{\partial b} = 1
  • zx1=w1\frac{\partial z}{\partial x_1} = w_1
  • zx2=w2\frac{\partial z}{\partial x_2} = w_2

따라서:

  • Lw1=Lzx1\frac{\partial \mathcal{L}}{\partial w_1} = \frac{\partial \mathcal{L}}{\partial z} \cdot x_1
  • Lw2=Lzx2\frac{\partial \mathcal{L}}{\partial w_2} = \frac{\partial \mathcal{L}}{\partial z} \cdot x_2
  • Lb=Lz\frac{\partial \mathcal{L}}{\partial b} = \frac{\partial \mathcal{L}}{\partial z}

5. 한 줄 정리

Lwi=(y^y)1z>0xi\frac{\partial \mathcal{L}}{\partial w_i} = (\hat{y} - y) \cdot \mathbf{1}_{z > 0} \cdot x_i
Lb=(y^y)1z>0\frac{\partial \mathcal{L}}{\partial b} = (\hat{y} - y) \cdot \mathbf{1}_{z > 0}

이렇게 **체인 룰(chain rule)**을 통해 각 단계별로 미분값을 곱하면서 손실에 대한 가중치의 그래디언트를 구하고, 그걸로 가중치를 업데이트합니다.

profile
AI developer

0개의 댓글