Softmax-with-Loss 계층 계산그래프

Smiling Sammy·2021년 11월 29일
0
post-thumbnail

선행 지식

계산 그래프, 연쇄법칙, 역전파와 관련된 내용이 궁금하다면, 아래 포스팅을 참조하자.

Softmax-with-Loss 계층

softmax-with-loss 계층의 계산그래프는 아래와 같다.
이전 계층으로부터 입력은 (a1,a2,a3)(a_1, a_2, a_3)이며, Softmax 계층은 (y1,y2,y3)(y_1, y_2, y_3)를 출력한다. 정답 레이블은 (t1,t2,t3)(t_1, t_2, t_3)이며 cross entropy error 계층은 손실 L을 출력한다.

Softmax 계층 계산 그래프 (순전파)

Softmax 수식은 다음과 같다.

yk=exp(ak)i=1nexp(ai)y_k = {exp(a_k) \over \sum_{i=1}^{n}exp(a_i)}

위 식을 바탕으로 Softmax 계산 그래프는 다음과 같다.

Cross Entropy Error 계층 계산 그래프 (순전파, 역전파)

Cross Entropy Error 수식은 다음과 같다.

L=ktklog(yk)L = -\sum_k{t_klog(y_k)}

위 식을 바탕으로 Cross Entropy Error 순전파 계산 그래프는 다음과 같다.

반대로 역전파 그래프는 다음과 같다.

'log'노드의 역전파는 다음 식을 따른다.

y=logxy = logx
δyδx=1x{\delta{y} \over \delta{x}} = {1 \over x}

Softmax 계층 계산 그래프 (역전파)

위에서 구한 cross entropy error 역전파 값이 흘러 들어온다.

'X' 노드에서는 순전파의 입력들을 '서로 바꿔서' 곱한다.
아래 그림에서 위 '/' 노드로 가는 값은 다음과 같은 계산이 이루어진다.

t1y1exp(a1)=t1Sexp(a1)exp(a1)=t1S-{t_1 \over y_1}exp(a_1) = -t_1{S \over exp(a_1)}exp(a_1) = -t_1S

상류로 들어온 값은 (t1S)+(t2S)+(t3S)=S(t1+t2+t3)(-t_1S) + (-t_2S) + (-t_3S) = -S(t_1+t_2+t_3)이다.
순전파의 출력은 1S1\over S이므로 역전파의 출력은 1S1\over S의 미분 값인 1S2- {1\over S^2}를 곱해야한다.
따라서 S(t1+t2+t3)1S2=1S(t1+t2+t3)-S(t_1+t_2+t_3) * - {1\over S^2} = {1\over S}(t_1+t_2+t_3)가 된다.
또한 (t1,t2,t3)(t_1, t_2, t_3)은 정답레이블로 원-핫 벡터로 표현되어 있다.
따라서 t1+t2+t3=1t_1+t_2+t_3=1이 된다.
즉 역전파의 출력은 1S{1\over S}가 된다.

'+' 노드는 입력을 여과 없이 내보낸다.

'x' 노드는 입력을 서로 바꾼 곱셈이다.

exp 노드에서는 exp함수가 미분해도 exp(x)exp(x) 이기 때문에 각 갈래의 입력의 합에 exp(a1)exp(a_1)을 곱한 결과가 역전파이다.
(1St1exp(a1))exp(a1)=y1t1({1 \over S} - {t_1 \over exp(a_1)})exp(a_1) = y_1 - t_1이 된다.

정리

위의 모든 과정을 하나의 그림으로 요약하면 다음과 같다.

참고

profile
Data Scientist, Data Analyst

0개의 댓글