역전파 이해
이해를 위한 인공 신경망은
두 개의 입력 - 두 개의 은닉층 뉴런 - 두 개의 출력층 누련
이며 활성화 함수는 시그모이드를 사용
- 순전파(Forward Propagation)
파란 숫자는 입력값, 빨간 숫자는 각 가중치의 값을 의미
z1=w1x1+w2x2=0.3×0.1+0.25×0.2=0.08
z2=w3x1+w4x2=0.4×0.1+0.35×0.2=0.11
z1,z2는 은닉층 뉴런에서 시그모이드 함수를 지나고 결과값은 은닉층 뉴런의 최종 출력 값 h1,h2이다.
h1=sigmoid(z1)=0.51998934
h2=sigmoid(z2)=0.52747230
이는 다시 출력층 뉴런의 입력값이 되고
z3=w5h1+w6h2=0.45×h1+0.4×h2=0.44498412
z4=w7h1+w8h2=0.7×h1+0.6×h2=0.68047592
이후 출력층에서 시그모이드 함수를 지난 값은 최종 출력값(예측값)
o1=sigmoid(z3)=0.60944600
o2=sigmoid(z4)=0.66384491
선택한 손실함수(여기서는 MSE)에 따라 오차를 계산하고 전체 오차를 구한다.
Eo1=21(targeto1−outputo1)2=0.02193381
Eo2=21(targeto2−outputo2)2=0.00203809
Etotal=Eo1+Eo2=0.02397190
- 역전파 1단계(BackPropagation Step1)
출력층에서 입력층 방향으로 계산하며 가중치를 업데이트
업데이트할 가중치는 w5,w6,w7,w8
총 4개
미분의 연쇄 법칙을 통해 ∂w5∂Etotal를 계산
∂w5∂Etotal=∂o1∂Etotal×∂z3∂o1×∂w5∂z3
(1) 첫째 항
Etotal의 값은 전체 오차값으로 식은
Etotal=21(targeto1−outputo1)2+21(targeto2−outputo2)2
이에 ∂o1∂Etotal=2×21(targeto1−outputo1)2−1×(−1)+0
∂o1∂Etotal=−(targeto1−outputo1)=−(0.4−0.60944600)=0.20944600
(2) 둘째 항
시그모이드 함수의 미분은 f(x)×(1−f(x))이고
따라서 시그모이드 함수 출력값인 o1은
∂z3∂o1=o1×(1−o1)=0.60944600(1−0.60944600)=0.23802157
(3) 셋째 항
∂w5∂z3=h1=0.51998934
우변 모든 항 계산을 곱해주면
∂w5∂Etotal=0.20944600×0.23802157×0.51998934=0.02592286
경사 하강법에 따라 가중치를 업데이트 학습률은 0.5로 가정
w5+=w5−α∂w5∂Etotal=0.45−0.5×0.02592286=0.43703857
같은 원리
∂w6∂Etotal=∂o1∂Etotal×∂z3∂o1×∂w6∂z3→w6+=0.38685205
∂w7∂Etotal=∂o2∂Etotal×∂z4∂o2×∂w7∂z4→w7+=0.69629578
∂w8∂Etotal=∂o2∂Etotal×∂z4∂o2×∂w8∂z4→w8+=0.59624247
- 역전파 2단계
1 단계 이후 입력층 방향으로
층이 많다면 반복 수행한다.
∂w1∂Etotal=∂h1∂Etotal×∂z1∂h1×∂w1∂z1
(1) 첫째 항
∂h1∂Etotal=∂h1∂Eo1+∂h1∂Eo2
∂h1∂Eo1=∂z3∂Eo1×∂h1∂z3=∂o1∂Eo1×∂z3∂o1×∂h1∂z3
=−(targeto1−outputo1)×o1×(1−o1)×w5
=0.20944600×0.23802157×0.45=0.02243370
∂h1∂Eo2=∂z4∂Eo2×∂h1∂z4=∂o2∂Eo2×∂z4∂o2×∂h1∂z4=0.00997311
∂h1∂Etotal=0.02243370+0.00997311=0.03240681
(2) 둘째 항
∂z1∂h1=h1×(1−h1)=0.51998934(1−0.51998934)=0.24960043
(3) 셋째 항
∂w1∂z1=x1=0.1
즉 ∂w1∂Etotal=0.03240681×0.24960043×0.1=0.00080888
경사 하강법 업데이트
w1+=w1−α∂w1∂Etotal=0.3−0.5×0.00080888=0.29959556
이외에도
∂w2∂Etotal=∂h1∂Etotal×∂z1∂h1×∂w2∂z1→w2+=0.24919112
∂w3∂Etotal=∂h2∂Etotal×∂z2∂h2×∂w3∂z2→w3+=0.39964496
∂w4∂Etotal=∂h2∂Etotal×∂z2∂h2×∂w4∂z2→w4+=0.34928991
- 결과 확인
업데이트로 오차 감소가 있는지 확인
z1=w1x1+w2x2=0.29959556×0.1+0.24919112×0.2=0.07979778
z2=w3x1+w4x2=0.39964496×0.1+0.34928991×0.2=0.10982248
h1=sigmoid(z1)=0.51993887
h2=sigmoid(z2)=0.52742806
z3=w5h1+w6h2=0.43703857×h1+0.38685205×h2=0.43126996
z4=w7h1+w8h2=0.69629578×h1+0.59624247×h2=0.67650625
o1=sigmoid(z3)=0.60617688
o2=sigmoid(z4)=0.66295848
Eo1=21(targeto1−outputo1)2=0.02125445
Eo2=21(targeto2−outputo2)2=0.00198189
Etotal=Eo1+Eo2=0.02323634
기존의 전체 오차가 0.02397190였으므로 1번의 역전파로 오차가 감소한 것을 확인할 수 있다