본 교재에서는 오차 역전파법을 설명하고자 계산 그래프를 이용하여 설명한다. 여기서 말하는 그래프는 그래프 자료구조를 의미하며 이는 여러 Node와 Edge로 구성된다. 아래는 간단한 계산 그래프의 흐름과 예시이다.
문제: 한 개에 100원인 사과를 2개, 한 개에 150원인 귤 세개를 살 때 지불 금액을 구하여라. 단, 소비세 10%가 부과된다.
위와 같이 edge(간선)에는 이전 노드의 계산 결과를, Node(노드)에는 연산 부호를 적고 사과와 귤의 개수, 소비세 등은 변수로 취급하여 원 밖에 표기해 준다.
이처럼 계산을 왼쪽에서 오른쪽으로 진행하는 단계를 순전파(forward propagation)이라고 한다. 반대로 오른쪽에서 왼쪽으로의 전파도 가능한데 이것을 역전파(back propagation)라고 한다.
계산 그래프의 특징은 '국소적 계산'을 전파함으로써 최종 결과를 얻는다는 것이다. 전체에서 어떤 일이 벌어지든 상관없이 자신과 관계된 정보만으로 결과를 출력할 수 있다는 것이다.
그림과 같이 사과 뿐 아니라 여러 식품을 함께 구입하는 경우를 살펴보자. 그림에서 생략하였지만, 이 부분엔 사과 뿐 아니라 생선, 버섯, 돼지고기, 바나나 등을 사과와 같이 그래프로 표현이 되어있을 것이다. 하지만, 사과 입장에서는 그것과 관련이 없다. 그래서 생략이 가능한 것이다. 사과에 대한 국소적 계산은 사과 가격과 앞의 모든 연산 결과로 나온 숫자를 더해주기만 하면 된다. 이 외에는 아무것도 신경 쓸 것이 없다.
그렇다면 계산 그래프의 이점은 무엇일까?
만약 사과 가격이 오른다면 최종 금액에 어떠한 영향이 가는지 궁금하다고 해보자. 이는 '사과 가격에 대한 지불 금액의 미분'을 구하는 문제이다. 지불 금액을 , 사과 값을 라고 하면 수식은 와 같다.
사과 가격에 대한 지불 금액의 미분 값은 계산 그래프에서 역전파를 하여 구할 수 있다.
그림과 같이 역전파는 반대 방향의 화살표로 그려진다. 이는 '국소적 미분'을 전달하고 그 미분 값은 화살표의 아래에 적는다. 위 결과 '사과 가격에 대한 지불 금액의 미분 값'은 2.2라고 할 수있고, 사과가 1원 오르면 최종 금액은 2.2원이 오른다고 할 수 있다.
사과 가격에 대한 미분 값 뿐 아니라, 중간 계산결과도 이용하여 모든 변수의 미분을 구할 수 있다.
이렇게 '국소적인 미분'을 전달하는 원리는 연쇄법칙에 따른 것이다.
라는 계산의 역전파를 그려보면 아래와 같다.
계산 절차는 최종적으로 나온 신호 에 국소적 미분() 을 곱한 후 다음 노드로 전달하는 것이다. 이 국소적 미분은 순전파 때의 의 미분을 의미하므로 만약 이었다면 이다.
그래서 이를 가능하게 해주는 '연쇄법칙' 이란 무엇일까?
연쇄 법칙은 합성 함수의 미분에 대한 성질이며, 합성 함수의 미분은 합성 함수를 구성하는 각 함수의 미분의 곱으로 나타낼 수 있다.
다음 함성함수의 연쇄법칙을 살펴보자.
위에서 (에 대한 의 미분)은 (에 대한 의 미분)와 (에 대한 의 미분)의 곱으로 나타낼 수 있다.
여기서 는 서로 지울 수 있다.
그럼 이제 각각의 국소적 미분을 구해보자.
위 두개의 국소적 미분을 가지고 를 구해보자.
이제 같은 수식을 계산 그래프로 나타내 보면 아래와 같다.
의 노드에서의 역전파부터 시작해보자. 역전파이기 때문에 입력은 가 된다. 이는 1이므로 무시할 수 있다 결과 출력값은 가 된다. 노드에서의 역전파 입력값은 직전 출력값인 이고, 마찬가지로 결과는 순방향에서의 국소 미분값을 곱한 가 된다. 결국 나머지는 소거되고 이는 와 같다.
이렇게 역전파가 하는 일은 연쇄법칙의 원리와 같다.
라는 덧셈 노드의 역전파를 보자.
의 편미분은 이와 같이 계산되는데, 계산 그래프로는 아래와 같이 그려진다.
최종 신호가 인 계산그래프라고 가정을 하면, 위는 그 그래프의 일부이다.
상류에서 전해진 미분이 덧셈 노드를 거치면 결국 1만 곱해져서 출력되어 결국 입력된 값을 그대로 출력하는 특징을 지닌다.
이번엔 라는 곱셈 노드의 역전파를 살펴보자. 편미분은 아래와 같다.
계산 그래프를 보자.
곱셈 노드의 특징은, 순전파 때의 입력값들을 '서로 바꾼값'을 곱해서 하류로 보낸다는 점이다. 순전파 입력값이 였던 엣지는 역전파에선 를 곱해서 하류로 보내고, 순전파 입력값이 였던 엣지는 역전파에선 를 곱해서 하류로 보내는 것을 확인할 수 있다.
위에서 본 사과 문제로 덧셈, 곱셈 노드의 특징 예시를 살펴보자.
소비세를 곱하는 노드의 특징을 다시 보자. 순전파 때 입력값이 200이었던 노드는 1.1을 곱하고, 순전파 때 입력값이 1.1이었던 노드는 200을 곱해서 흘려보낸다.
귤까지 구매하는 계산 그래프에서 각 역전파의 결과값을 확인해보자.
덧셈 노드만 다시 살펴보자. 상류에서 입력된 값이 1.1이고, 하류로 내보내는 값들도 모두 똑같이 1.1인 것을 확인할 수 있다.
위에서 든 예제들을 직접 구현해보자.
class MulLayer:
def __init__(self):
self.x = None
self.y = None
def forward(self, x, y):
self.x = x
self.y = y
out = x * y
return out
def backward(self, dout):
dx = dout * self.y
dy = dout * self.x
return dx, dy
위 코드에서는 곱셈 노드 클래스를 정의한다. 순전파 함수(forward)는 그대로 두 입력값의 곱을 return 해주고, 역전파 함수(backward)는 상류에서 흘러들어온 입력값(dout)에서 출력값으로 엣지에는 를 곱한 값을, 반대로 엣지에는 를 곱한 값을 return 해준다.
apple = 100
apple_num = 2
tax = 1.1
mul_apple_layer = MulLayer()
mul_tax_layer= MulLayer()
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)
print(price)
출력 결과
그대로 사과 예제로 적용해보자. 먼저 순전파 코드이다. 사과 가격에 사과 개수를 먼저 곱하고, 그 결과에 소비세를 곱해주는 예시이다.
결과로 220을 잘 출력하는 것을 확인할 수 있다.
참고로 220.0이 아닌 220.00000000000003을 출력하는 점은 cs적인 내용이므로 굳이 설명을 첨가하진 않겠다.
아래는 역전파 코드이다.
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)
print(dapple, dapple_num, dtax)
출력 결과
각각 순서대로 사과 가격에 대한, 사과 개수에 대한, 소비세에 대한 총 가격 미분값의 결과를 잘 보여주고 있다.
class AddLayer:
def __init__(self):
pass
def forward(self, x, y):
return out
def backward(self, dout):
dx = dout * 1
dy = dout * 1
return dx, dy
덧셈 클래스이다. 순전파(forward)는 입력받은 를 더한 값을 return, 역전파(backward)는 입력값에 1을 곱한 결과, 즉 입력값 그대로를 다시 출력값으로 return 해준다.
apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1
# 계층들
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()
# 순전파
apple_price = mul_apple_layer.forward(apple, apple_num)
orange_price = mul_orange_layer.forward(orange, orange_num)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)
price = mul_tax_layer.forward(all_price, tax)
# 역전파
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)
dorange, dorange_num = mul_orange_layer.backward(dorange_price)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)
print(f'전체 가격 : {price}')
print(f'사과 개수에 대한 전체 가격의 미분값 : {dapple_num}')
print(f'사과 가격에 대한 전체 가격의 미분값 : {dapple}')
print(f'오렌지 개수에 대한 전체 가격의 미분값 : {dorange_num}')
print(f'오렌지 가격에 대한 전체 가격의 미분값 : {dorange}')
print(f'소비세에 대한 전체 가격의 미분값 : {dtax}')