x
, +
, *
) 를 나타내고, 엣지는 데이터(값)가 흐르는 방향을 나타냄.x ----> (*) ----> y
↑
3
x
에 3을 곱해서 y
를 만든다면, x
, 3
은 입력, *
는 곱셈 노드, y
는 출력.MulLayer
, AddLayer
는 각각 계산 그래프의 노드를 클래스로 구현한 것.
class AddLayer:
def __init__(self):
pass # 덧셈은 중간값 저장 안 해도 되기 때문에 변수 없음
def forward(self, x, y):
out = x + y
return out
def backward(self, dout):
# 상류로부터 받은 미분(dout)을 그대로 x와 y 양쪽으로 전달
dx = dout * 1
dy = dout * 1
return dx, dy
class MulLayer:
def __init__(self):
self.x = None
self.y = None # 역전파 계산을 위해 x, y를 저장해둠
def forward(self, x, y):
self.x = x
self.y = y
out = x * y
return out
def backward(self, dout):
# ∂z/∂x = y, ∂z/∂y = x → 상류 미분에 곱해서 전달
dx = dout * self.y
dy = dout * self.x
return dx, dy
price_apple = 100
apple_num = 2
tax = 1.1
mul_apple_layer = MulLayer() # 사과 개수만큼 곱함
mul_tax_layer = MulLayer() # 세금 계산
price_apple_num = mul_apple_layer.forward(price_apple, apple_num) # 100 * 2 = 200
price_apple_tax = mul_tax_layer.forward(price_apple_num, tax) # 200 * 1.1 = 220.0
price_apple_num
이 먼저 계산되고, 이 값과 tax가 곱해져 price_apple_tax
가 됨.
중간 결과값 저장 덕분에 역전파에서 효율적 계산 가능.
dprice = 1 # 최종 출력값에 대한 미분 1부터 시작
dprice_apple_num, dtax = mul_tax_layer.backward(dprice) # tax 방향으로 200, apple_num 방향으로 1.1
dprice_apple, dapple_num = mul_apple_layer.backward(dprice_apple_num)
dprice = 1
은 ∂L/∂L = 1 (출력에 대한 미분은 자기 자신 기준으로 1)MulLayer
의 역전파는 입력 값을 기억해두었다가, dout
에 곱해서 각각의 입력에 대한 미분을 구함.price_apple = 100
apple_num = 2
price_orange = 150
orange_num = 3
tax = 1.1
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()
price_apple_num = mul_apple_layer.forward(price_apple, apple_num) # 100 * 2 = 200
price_orange_num = mul_orange_layer.forward(price_orange, orange_num) # 150 * 3 = 450
price_all_num = add_apple_orange_layer.forward(price_apple_num, price_orange_num) # 200 + 450 = 650
price_all_tax = mul_tax_layer.forward(price_all_num, tax) # 650 * 1.1 = 715.0
dprice = 1
dprice_all_num, dtax = mul_tax_layer.backward(dprice) # dprice_all_num=1.1, dtax=650
dprice_apple_num, dprice_orange_num = add_apple_orange_layer.backward(dprice_all_num) # 1.1씩 나눠짐
dprice_orange, dorange_num = mul_orange_layer.backward(dprice_orange_num)
dprice_apple, dapple_num = mul_apple_layer.backward(dprice_apple_num)