Cross Entropy Loss 미분하기

Junseong Park·2024년 8월 18일

Boostcamp AI Tech 7기

목록 보기
1/1
post-thumbnail

사전설명

batch sizeNN이고 label의 수가 CC일 때, Cross Entropy Loss는 다음과 같이 분류 모델을 평가합니다.

L=1Ni=1Nc=1Cyiclogy^ic\mathcal{L} = -\frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C y_{ic}\log{\hat y_{ic}}

yicy_{ic}ii번째 예제의 labelcc라면 11이고, 그렇지 않으면 00입니다.
y^ic\hat y_{ic}는 분류 모델이 ii번째 예제의 labelcc일 확률을 구한 값입니다.

Cross Entropy Loss가 왜 저렇게 구해지는지에 대해서는 Maximum Likelyhood Estimation에 대해서 이야기해야 합니다만, 그 이야기는 너무 길어지니 다른 포스팅에서 소개하도록 하겠습니다.

그런데 Cross Entropy Loss만 미분하면 포스팅이 너무 짧아지니 y^ic\hat y_{ic}까지 한 꺼풀 더 벗겨서 미분해보도록 하겠습니다.
분류 모델의 예측 확률 y^ic\hat y_{ic}는 보통 softmax\text{softmax}라는 함수를 사용해서 구하게 됩니다. xRN×Cx \in \mathbb{R}^{N \times C}일 때 그 식은 아래와 같습니다.

y^ij=softmax(x)ij=exijexi1+exi2++exiC\hat y_{ij} = \text{softmax}(x)_{ij} = \frac{e^{x_{ij}}}{e^{x_{i1}}+e^{x_{i2}}+\cdots + e^{x_{iC}}}

오늘은 간략히 Lx\frac{\partial \mathcal{L}}{\partial x}의 값을 계산해보도록 하겠습니다.

미분

모델의 손실 함수는 대부분 한번에 미분하기 쉽지 않아서 아래와 같이 Chain rule을 이용하여 미분값을 계산합니다.

Lx=Ly^y^x\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial \hat y} \cdot \frac{\partial \mathcal{\hat y}}{\partial x}

즉, y^\hat y에 대한 L\mathcal{L}의 변화량(Ly^\frac{\partial \mathcal{L}}{\partial \hat y})과 xx에 대한 y^\hat y의 변화량(y^x\frac{\partial \mathcal{\hat y}}{\partial x})의 곱을 구해주면 됩니다. 전자부터 구해보겠습니다. 계산의 편의를 위해 N=1N=1이라고 가정하겠습니다. (연산이 행 단위로만 일어나서 이렇게 가정해도 문제가 없습니다.)

labelii라고 하면, yi=1y_i = 1이고 yj=0(ij)y_j = 0 \,\, (i \neq j)이므로 아래와 같이 정리할 수 있습니다.

L=yilogy^i\mathcal{L} = -y_i \log \hat y_i

위 식을 y^i\hat y_i에 대해서 미분하면 아래와 같습니다.

Ly^i=yiy^iLy^=[y1y^1,y2y^2,,yCy^C]=yy^\begin{aligned} \frac{\partial \mathcal{L}}{\partial \hat y_i} &= - \frac{y_i}{\hat y_i} \\ \frac{\partial \mathcal{L}}{\partial \hat y} &= -\left[\frac{y_1}{\hat y_1}, \frac{y_2}{\hat y_2}, \cdots, \frac{y_C}{\hat y_C} \right] = -\frac{y}{\hat y} \end{aligned}

이제 softmax\text{softmax} 함수를 미분해봅시다. softmax\text{softmax}는 분수함수의 미분 공식을 이용하면 쉽게 미분할 수 있습니다.

(f(x)g(x))=f(x)g(x)f(x)g(x)(g(x))2\left( \frac{f(x)}{g(x)} \right)' = \frac{f'(x)g(x) - f(x)g'(x)}{(g(x))^2}

이때, yj=0(ij)y_j = 0 \,\, (i \neq j)이므로 Ly^i\frac{\partial \mathcal{L}}{\partial \hat y_i}만 고려하면 됩니다. 따라서 y^i\hat y_i에 대한 x1,x2,,xCx_1, x_2, \cdots, x_C의 변화량을 구해봅시다.

y^ixk=(exi)(ex1+ex2++exC)exiexk(ex1+ex2++exC)2\frac{\partial \hat y_i}{\partial x_k} = \frac{(e^{x_i})'(e^{x_1}+e^{x_2}+\cdots+e^{x_C}) - e^{x_i}\cdot e^{x_k}}{(e^{x_1}+e^{x_2}+\cdots+e^{x_C})^2}

이때 (exi)(e^{x_i})'의 값은 k=ik = i라면 exie^{x_i}이며 kik \neq i라면 00입니다. yky_k의 값이 i=ki = k일 때 11, 아닐 때 00임을 이용해서 간단히 ykexiy_k \cdot e^{x_i}라고 나타낼 수 있습니다. 이를 이용하여 식을 정리하면 다음과 같습니다.

y^ixk=yk(exi)(ex1+ex2++exC)exiexk(ex1+ex2++exC)2=yk(exiex1+ex2++exC)(exiex1+ex2++exC)(exkex1+ex2++exC)=yky^iy^iy^k=y^i(yky^k)\begin{aligned} \frac{\partial \hat y_i}{\partial x_k} &= \frac{y_k(e^{x_i})(e^{x_1}+e^{x_2}+\cdots+e^{x_C}) - e^{x_i}\cdot e^{x_k}}{(e^{x_1}+e^{x_2}+\cdots+e^{x_C})^2} \\ &= y_k \cdot \left(\frac{e^{x_{i}}}{e^{x_{1}}+e^{x_{2}}+\cdots + e^{x_{C}}} \right) - \left( \frac{e^{x_{i}}}{e^{x_{1}}+e^{x_{2}}+\cdots + e^{x_{C}}} \right)\left( \frac{e^{x_{k}}}{e^{x_{1}}+e^{x_{2}}+\cdots + e^{x_{C}}} \right) \\ &= y_k\cdot \hat y_i - \hat y_i \cdot \hat y_k \\ &= \hat y_i(y_k - \hat y_k) \end{aligned}

엄청나게 복잡해보이는 수식이었지만 식을 다 정리하고 나니 굉장히 깔끔해진 것을 확인해볼 수 있습니다.

이제 마지막으로 구한 두 수식을 구해지면 L\mathcal{L}에 대한 xx의 변화량을 구할 수 있습니다.

Lxk=Ly^y^xk=Ly^iy^ixk=(1y^i)(y^i(y^kyk))=yky^kLx=[y1y^1,y2y^2,,yCy^C]=yy^\begin{aligned} \frac{\partial \mathcal{L}}{\partial x_k} &= \frac{\partial \mathcal{L}}{\partial \hat y} \cdot \frac{\partial \mathcal{\hat y}}{\partial x_k} \\ &= \frac{\partial \mathcal{L}}{\partial \hat y_i} \cdot \frac{\partial \mathcal{\hat y}_i}{\partial x_k} \\ & = \left(- \frac{1}{\hat y_i} \right) \left( \hat y_i (\hat y_k - y_k) \right) \\ & = y_k - \hat y_k \\ \frac{\partial \mathcal{L}}{\partial x} &= \left[ y_1 - \hat y_1, y_2 - \hat y_2, \cdots, y_C - \hat y_C \right] = y - \hat y \end{aligned}

짜잔! xx에 대한 손실 함수의 변화량은 yy^y - \hat y로 아주 간단하게 정리된다는 것을 알 수 있었습니다!

결론

L=i=1Cyilogy^iy^k=exkex1+ex2++exCLx=yy^\begin{aligned} \mathcal{L} &= -\sum_{i=1}^C y_i \log \hat y_i \\ \hat y_{k} &= \frac{e^{x_{k}}}{e^{x_{1}}+e^{x_{2}}+\cdots + e^{x_{C}}} \\ \frac{\partial \mathcal{L}}{\partial x} &= y - \hat y \end{aligned}

마치며

수식이 많으니까 벨로그 에디터가 무언가 렉이 걸리는 느낌이 드네요... 수식이 많이 나오는 포스팅은 지양하는 편이 좋겠다는 교훈을 얻었습니다. 근데 포스팅 주제로 생각해둔 것들이 대부분 수식을 많이 써야 할 것 같아서 주제를 더 생각해봐야겠네요.

다음 포스팅의 주제는 일단은 간단하게 chain rule 유도하는 법에 대해서 생각중입니다만 바뀔 수도 있을 것 같습니다.

그럼 다음에 뵙겠습니다. 감사합니다.

profile
부스트캠프 AI Tech 7기

0개의 댓글