- Research 프로젝트는 같은 데이터셋에 대해서 다른 접근법을 취하는 경향이 있음
- Pytorch Lightning에서는 이 부분을 상속을 통해서 아주 쉽게 처리함
- 예를 들면, MNIST 이미지에서 feature를 뽑으려고 AutoEncoder를 학습시킨다고 할 때, dataloader 설정이 되어있는 LitMNIST라는 모듈을 확장해서 사용할 수 있음
- Autoencoder 모델에서 init/forward/training/validation/test step만 변경해주면 됨
class Encoder(torch.nn.Module):
pass
class Decoder(torch.nn.Module):
pass
class AutoEncoder(LitMNIST):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.metric = MSE()
def forward(self, x):
return self.encoder(x)
def training_step(self, batch, batch_idx):
x, _ = batch
representation = self.encoder(x)
x_hat = self.decoder(representation)
loss = self.metric(x, x_hat)
return loss
def validation_step(self, batch, batch_idx):
self._shared_eval(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
self._shared_eval(batch, batch_idx, "test")
def _shared_eval(self, batch, batch_idx, prefix):
x, _ = batch
representation = self.encoder(x)
x_hat = self.decoder(representation)
loss = self.metric(x, x_hat)
self.log(f"{prefix}_loss", loss)
- 이 경우, 기존 사용하던 같은 trainer 인스턴스로 학습이 가능함
autoencoder = AutoEncoder()
trainer = Trainer()
trainer.fit(autoencoder)
- Lightning Module을 Pytorch model 처럼 사용하고 싶은 경우, 꼭 forward method를 구현해 놓아야함
some_images = torch.Tensor(32, 1, 28, 28)
representations = autoencoder(some_images)