import torch
from torch import nn
class Add(nn.Module):
def __init__(self.value):
super().__init__()
self.value = value
def forward(self, x):
return x + self.value
forward함수는 nn.Module을 상속받은 클래스로부터 만든 model 객체를 데이터와 함께 호출하면 자동으로 실행이 된다. 예를 들어서 model이라는 이름의 객체를 생성하고, (위의 경우에) model(x)와 같이 객체를 호출하면 자동으로 forward연산이 수행된다.
https://jungeui.tistory.com/26
위의 모듈을 사용하고 싶을때는 아래와 같이 사용하면 된다.
model = Add(value) <- value는 model의 value로 들어간다.
print(model(x)) <- x는 forward함수의 인자로 들어간다.
# 이런식으로도 가능
Add(value)(x) <- value는 객체 초기화할때 self.value로 들어가고 x는 forward 함수의 인자로 들어감
sequential은 순차적인 컨테이너다 모듈들이 이 안에 더해지며 컨테이너 안에서 순서대로 실행된다.
calculator = nn.Sequential(Add(1),
Add(2),
Add(3))
x = 1
output = calculator(x)
## output == 7
# 이렇게도 실행 가능
# nn.Sequential(Add(1), Add(2), Add(3))(1)
calculator(x)는 nn.Sequential(Add(1), Add(2), Add(3))(x)로 볼 수 있다. 실행 과정은 아래와 같(은 걸로 추청..?)한다.
tmp = Add(1)(x)
tmp = Add(2)(tmp)
tmp = Add(3)(tmp)
output = tmp
class CustomDataset(Dataset):
def __init__(self):
# 데이터를 불러오고 data와 target을 넣어준다.
def __len__(self):
# 데이터의 길이를 반환해준다.
def __getitem__(self):
# data와 target에 인덱스를 먹이고 어떤 형태로 반환할 것인지 정의한다.
# 이미지의 경우 transforms도 정의해줄 수 있다.
https://pytorch.org/docs/stable/data.html
Dataloader는 dataset의 index를 이용해 배치 단위로 데이터를 제공한다.
인자로는 아래와 같이 있다.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
# 참고: https://pytorch.org/docs/stable/data.html
이 중 batch_size나 collate_fn인자는 자주 사용된다고 한다.
참고: 부스트캠프 AI Tech 4기 Pytorch강의 [최성철 교수님]