Pytorch에서의 Backpropagation

용가리·2024년 8월 12일

Pytorch

목록 보기
7/7

반갑습니다.
오늘은 Torch에서의 Backpropagation에 대해 설명해보겠습니다.

Backpropagation이 뭔가요?

Backpropagation은 한국어로 역전파입니다.
머신러닝, 딥러닝에서의 핵심 개념중 하나로, 정답 레이블과 모델의 예측값이 얼마나 다른지에 대한 로스를 모델의 parameters에게 전달하는 역할을 합니다.

사실 Backpropagation은 워낙 핵심개념이기도 해서 다들 알고계실거라 생각합니다.
저도 잘 안다고 생각했는데, 생각보다 설명하려 하니 많이 어렵더라구요.
제가 이해한 만큼 설명해보겠습니다.

    loss = loss_function(y, y_hat) 
    optimizer.zero_grad()
    loss.backward()  
    optimizer.step()

위 코드는 파이토치에서 학습을 진행할 때 사용하는 코드블럭입니다.
대부분의 학습은 위 과정을 따라간다고 생각합니다.
학습 코드야 워낙 다 같으니 별 생각없이 쓰는 분들이 많다고 생각합니다.
저 또한 그랬죠.
근데 코드 한줄 한줄 어떤 역할을 하는지에 대해 설명해보려 하니 많이 어렵더라구요. 이번기회에 정리할 수 있게 돼서 좋습니다.

1. loss = loss_function(y,y_hat)

우리가 제작한 모델에서 나온 출력값을 사전에 정의한 로스함수를 사용하여 정답과 비교합니다.
함수에 대입한 로스의 값이 나옵니다.

2. optimizer.zero_grad()

사전에 정의한 옵티마이저의 기울기를 0으로 초기화합니다.
0으로 초기화하지 않으면, 이전 학습때 각 파라미터에 저장되어 있던 grad가 이번 학습에 영향을 끼치게 되므로, 이를 방지하기 위해서입니다.

3. loss.backward()

로스를 역전파 합니다.
거시적으로는 정말 간단히 설명되지만, 과정이 어떻게 이루어지는지에 대해서 설명을 하려면 어려움이 많습니다.


간단한 함수의 역전파 과정을 수식으로 정리해봤습니다.
역전파는 chain rule을 통해 로스함수를 목표 파라미터에 대해 미분으로 나타낼 수 있도록 하는 과정입니다.
어떤 파라미터를 로스에 대해 미분하고자 하는데 그 파라미터에 로스에 대한 변수가 없다면 미분할 수 없겠죠.
로스에서 목표 파라미터까지 미분할 수 있는 변수들로 나타낼 수 있도록 하는겁니다.

제가 헷갈렸던 부분은 로스는 하나의 값으로 나타나는데, 이것을 어떻게 앞단으로 보내냐 입니다.
앞으로 전달되는 것은 로스의 값이 아닌 로스의 미분식에 대한 대입값인데, 이 부분을 캐치하지 못했네요.
간단한 예를 들어보겠습니다.

BinaryCrossEntropy에 대한 식은

loss=(ylog(yhat)+(1y)log(1yhat))loss = -(y log(yhat)+(1-y)log(1-yhat))

인데, 이를 미분하면

(y/yhat(1y)/(1yhat))-(y/yhat - (1-y)/(1-yhat))

입니다.
그럼 미분 식에 y_hat,y를 넣고 값을 얻습니다.
다른 Weight나 activation_function들도 미분한 식이 있을거고, 대입하면 값이 나오겠습니다.
위 노트에 정리한 활성화함수는 시그모이드 함수입니다.
시그모이드 함수도 미분하면 sigmoid(x)(1-sigmoid(x))가 나오는데, 여기에도 대입하여 값을 구할 수 있습니다.
구한 값들을 바탕으로 체인 룰에 맞게 값들을 곱해주면 역전파를 수행 할 수 있습니다.

ReLU는 x=0에서 미분가능하지 않은데 어떻게 역전파 하죠?

ReLU 함수는 미분값이 음수에서는 0, 양수에서는 1입니다.
하지만 x=0에서는 미분할 수 없죠.
찾아본 결과, 딱 한 점에서의 미분 불가능은 큰 영향은 없다고 합니다.
학습 중에 왠만하면 값이 딱 0에 맞게 나오지 않을 뿐만 아니라, 0이 나온다면 예외처리를 하면 그만이니까요.

4. optimizer.step()

앞서 구한 역전파를 통해 각 파라미터에는 grad가 계산되어 있습니다.
계산된 grad를 optimizer 함수에 넣어 가중치를 갱신합니다.
W1 = W1 - optimizer(grad)가 되는 셈이지요.

'
'
'
'
이렇게 네가지 코드에 대해 좀 자세히 살펴봤습니다.
당연히 알고 있다 생각했는데, 설명하고자 하니 많은 어려움이 있네요.
설명할 수 있는 단계가 배움의 끝이라고 생각하는데, 아직 멀었군요 ㅠ
혹시나 틀린 부분이 있다면 알려주시면 고쳐보도록 하겠습니다.
감사합니다 !

0개의 댓글