nn.Module
을 이용하여 리팩토링 하기torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
import torch.nn as nn
from torch.autograd import Variable
input = torch.ones(1,1,3,3)
input = Variable(input, requires_grad=True)
func = nn.Conv2d(1,1,3) # input, ouput, kernel_size
print(func.weight)
out = func(input)
print(out)
out.backward()
print(input.grad)
PyTorch의 nn
클래스의 장점을 활용하여 코드를 더 간결하고 유연하게 만들 수 있음
활성화, 손실 함수를 torch.nn.functional
의 함수로 대체
함수: 인스턴스화 시킬 필요 없이 사용 가능
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1 groups=1)
변수 설명
1) weight: 외부에서 만든 필터를 넣어줘야함
2) 나머지는 동일
import torch.nn.functional as F
from torch.autograd import Variable
input = torch.ones(1,1,3,3)
fileter = torch.ones(1,1,3,3)
input = Varialbe(input, requires_grad=True)
fileter = Variable(filter)
out = F.conv2d(input, filter)
out.backward()
print(out_grad_fn) #ConvNdBackward object at~>