기본적으로 DeZero는 numpy ndarray로 연산을 수행한다.
따라서 Variable 내에서 ndarray가 아닌 data type에 대해 예외처리한다.
import numpy as np
def as_array(x): #제곱등의 넘파이 연산시 ndarray에서 np float 64등으로 바뀌게 되는데 이를 방지하기 위해서 as_array선언
#Function 클래스의 Forward결과를 Variable 객체로 상속할때 사용한다. Forward에 사용되는 넘파이 연산이 스칼라값을 반환하기 때문
if np.isscalar(x):
return np.array(x)
else:
return x
class Variable():
def __init__(self,data):
self.data = data
if data is not None:
if not isinstance(data,np.ndarray): #ndarray타입만 지원하기 위한 예외처리
raise TypeError(f"{type(data)}는 지원하지 않습니다.")
self.grad = None
self.creator = None #계산 그래프상 이전 함수를 follow up 하기 위한 클래스 변수
def set_creator(self,func):
self.creator = func #해당 변수를 만든 함수 객체를 지정해주는 부분
def backward(self):
if self.grad is None: #dy/dy는 무조건 1인데 이를 연산시에 지정하고 싶지 않으므로 이렇게 지정해둔다
self.grad = np.ones_like(self.data) #ones_like를 쓰는 이유는 데이터 타입까지 따라가기 위해서
#위 코드의 구체적 설명 : 현재 코드는 y라는 아웃풋을 가지고 함수의 backward()를 이용해 x라는 인풋의 grad를 구하면서 역전파를 수행
#근데 가장 먼저 알아야하는 grad인 dy/dy의 경우 사실 무조건 1이다. 이를 따로 지정하지 않으면 y라는 객체를 처음 생성했을 때 grad가 None이 되는데 (Variable의 init 참조)
#이 값이 무조건 1이 되어야하므로 None인 경우에 grad를 1로 지정한다. 이 dy/dy에 해당하는 grad를 제외하고는 chain rule에 의해 None이 아닌 가지게 되므로 원하는 역할을 수행한다.
funcs = [self.creator] #함수를 리스트로 저장해서 역전파 구현
while funcs: #이전에 연산한 함수가 있느 동안 계속 실행
f = funcs.pop() #존재하는 가장 최근의 연산을 pop해서 지정
print(f)
x = f.input #해당 함수로의 인풋값
y = f.output #해당 함수에서의 결과값 지정
x.grad = f.backward(y.grad) #Function 객체에서의 backward
if x.creator is not None: #만약 이전 함수가 있다면 즉 추가로 역전파할 함수가 있다면
funcs.append(x.creator) #해당 함수를 funcs리스트에 추가
class Function: #순전파와 역전파를 구현하는 클래스
def __call__(self,input): #Variable 객체를 인자로 받는다
x = input.data
y = self.forward(x) #forward의 결과를 y로 저장
output = Variable(as_array(y)) #이를 ndarray로 바꾸고 Variable객체 output 생성
output.set_creator(self) #이 생성된 output의 creator(함수)는 자기자신이므로 self를 지정
self.input = input #input
self.output = output #output을 클래스 변수로 지정해두는데 이는 역전파를 할때 용이하게 하기 위함
return output
def forward(self,x): #forward와 backward의 구체적 내용은 상속을 통해 지정한다
raise NotImplementedError()
def backward(self,gy):
raise NotImplementedError()
#Function 클래스를 상속받는다. 구현하고 싶은 함수의 내용을 forward, 그 forward의 미분 공식을 backward에 구현한다.
class Square(Function):
def forward(self,x):
return x**2
def backward(self,gy):
return gy*(self.input.data)*2
class Exp(Function):
def forward(self,x):
return np.exp(x)
def backward(self,gy):
return gy*np.exp(self.input.data)
def square(x): #함수로 간소화
return Square()(x) #function 객체를 반환함은 여전하다
def exp(x):
return Exp()(x)