Neural Network 오차 역전파(back propagation)

Surf in Data·2022년 4월 3일
0

deep learning

목록 보기
3/9
post-thumbnail

backpropagation(역전파)

심층신경망에서는 신경망의 크기가 너무 커지고, 입력이나 출력의 개수가 많아지면서 graient descent의 방법은 사실상 불가능하다. 따라서 Neural Network에서는 출력층부터 역순으로 Gradient 를 전달하여 전체 Layer의 가중치를 Update하는 방식을 사용한다.

간단한 코드로 back propagation 코드 쿠구현을 해보기 위해서 다음과 같은 가정을 해보겠다.
이진분류 사람이냐 아니냐를 판별하는 모델이라고 하면 판별에 쓰이는 피처는 2가지 x1, x2 두가지이고 은닉층의 활성화함수는 Relu함수 마지막 output에서의 활성화함수는 이진분류이므로 sigmoid함수를 사용했고 sigmoid함수를 통과한 예측값을 z_hat, 실제값을 z라하겠다.

  • relu함수 구현
def relu(x):
    if x > 0:
        x = x
    else:
        x = 0
    
    return x
  • sigmoid함수 구현
def sigmoid(x):
    return 1 / (1 +np.exp(-x))
  • 이진분류의 target값:
    피처 2개
x1, x2, z = 0.7, 0.3, 1.0
print(x1, x2, z)
0.7 0.3 1.0

back propagation을 통해 많은 가중치중 w11의 가중치를 업데이트 하는 함수를 구현

def back_propagation(x1, x2, z, iterations):
    
    #learning_rate
    learning_rate = 0.5
    
    #ramdom weight 값
    np.random.seed(2)
    w1 = round(np.random.rand(), 3)
    w11 = round(np.random.rand(), 3)
    w12 = round(np.random.rand(), 3)
    w21 = round(np.random.rand(), 3)
    w2 = round(np.random.rand(), 3)
    w22 = round(np.random.rand(), 3)
    
    weights = pd.DataFrame({"초기weight": ["w1", "w11", "w12", "w21", "w2", "w22"], "값": [w1, w11, w12, w21, w2, w22]})
    weights.style.hide_index()
    print(weights)
    print("---------------------------")
    
    for iteration in range(iterations):
        #propagartion
        f1 = x1*w1 + x2*w2
        f2 = x1*w12 + x2*w2

        z1 = relu(f1)
        z2 = relu(f2)

        f = z1*w11 + z2+w22
        z_hat = sigmoid(f)

        CE = -(z*np.log(z_hat)+(1-z)*np.log(z_hat)) #이진분류에서의 오차값 cross enthrophy
        reult = pd.DataFrame()
        print(f"iteration{iteration + 1} propagation완료")
        print("z_hat", z_hat, "CE", CE)
        
        ## W11_UPDATE
        print("업데이트 되기전 w11: ", w11)
        print("w11 업데이트")
        w11 = w11 + learning_rate*(1/z_hat)*sigmoid(f)*(1-sigmoid(f))*z1
        print("업데이트 된후의 w11: ", w11)
        print("---------------------------")
back_propagation(x1, x2, z, 3)
  초기weight      값
0       w1  0.436
1      w11  0.026
2      w12  0.550
3      w21  0.435
4       w2  0.420
5      w22  0.330
---------------------------
iteration1 propagation완료
z_hat 0.7010307863243558 CE 0.35520347518814194
업데이트 되기전 w11:  0.026
w11 업데이트
업데이트 된후의 w11:  0.09045776246846889
---------------------------
iteration2 propagation완료
z_hat 0.7068233374476314 CE 0.3469745206215468
업데이트 되기전 w11:  0.09045776246846889
w11 업데이트
업데이트 된후의 w11:  0.15366665091475956
---------------------------
iteration3 propagation완료
z_hat 0.7124393640689541 CE 0.3390604735219194
업데이트 되기전 w11:  0.15366665091475956
w11 업데이트
업데이트 된후의 w11:  0.21566472402149306
---------------------------

아래의 그림을 보면 첫번째 propagation 을 통한 CE의 값은 0.3552이다.

back propagation을 통해 w11의 가중치를 0.026에서 0.904로 업데이트 해주었고 CE는 0.3552에서 0.3469로 감소하였다.

bcak propagation 에서 w11의 가중치를 업데이트 하는 과정을 수식으로 이해하기

profile
study blog

0개의 댓글