📌 pytorch에 대한 high-level 인터페이스를 제공하는 오픈 소스 라이브러리
GPU나 TPU, 16bit precision, 분산학습 등 더욱 복잡한 조건에서 실험하게 될 경우 코드가 복잡해진다.
pytorch lightning의 목적은 코드의 추상화를 통해 좀 더 정돈되고 간결화된 코드를 작성하는 데에 있다.
pytorch lightning은 Lightning Model class 내에 DataLoader, Model, optimizer, Training roof 등을 한번에 구현하도록 되어 있다. 클래스 내부에 있는 함수명은 똑같이 써야하고 그 목적에 맞게 코딩하면 된다
Lightning Module은 모델 내부의 구조를 설계하는 research & science 클래스라고 생각할 수 있다. 모델의 구조나 데이터 전처리, 손실함수 등의 설정을 통해 모델을 초기화한다. 실제로 코드에서는 pl.LightningModule 클래스를 상속받아 새로운 LightningModule 클래스를 생성한다. 기존 PyTorch의 nn.Module과 같은 방식이라고 보면된다.
Trainer는 모델의 학습을 담당하는 클래스라고 볼 수 있다. 모델의 학습 epoch이나 batch 등의 상태뿐만 아니라, 모델을 저장해 로그를 생성하는 부분까지 담당한다. 실제로 코드에서는 pl.Trainer()라고 정의하면 끝이다.
pytorch lightning은 아래 명령어를 통해 간단히 설치할 수 있다
pip install pytorch-lightning
pytorch lightning을 통해 딥러닝 모델을 작성하는 순서는 아래와 같다.
lightning module은 trainer와 model이 상호작용할 수 있게 해주는 구현체이다.
기존 pytorch와 달리, DataLoader, Model, optimizer, Training loof 등을 Lightning Module Class 안에 한번에 구현하도록 되어있다.
모듈 정의를 위해 LightningModule 클래스를 상속받고 학습에 필요한 메서드를 구현해야한다.
Lightning Module은 아래 6가지로 구성된다
Computations (__init__
)
torch.nn.module
에서 불러오거나 확장해야한다Train loop (training_step
)
nn.Module
의 forward
와 유사하지만, 단일 배치에서의 loss를 반환해야하며, 이는 train loop로 자동 반복된다.log
메서드를 사용해야한다. 만약, training_step의 결과를 무엇인가 할 일이 있으면 training_epoch_end
메서드를 작성한다.def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
# logs metrics for each training_step,
# and the average across the epoch, to the progress bar and logger
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def training_epoch_end(self, training_step_outputs):
for pred in training_step_outputs:
...
Validation loop (validation_step
)
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
self.log("val_loss", loss)
pred = ...
return pred
def validation_epoch_end(self, validation_step_outputs):
for pred in validation_step_outputs:
...
test_step
)# call after training
trainer = Trainer()
trainer.fit(model)
# automatically auto-loads the best weights from the previous run
trainer.test(dataloaders=test_dataloader)
Prediction loop (predict_step
)
Optimizers (configure_optimizers
)
아래는 뼈대 코드이다
import pytorch_lightning as pl
class Classifier(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
...
)
def forward(self, x):
pass
def training_step(self, batch, batch_idx):
pass
def validation_step(self, batch, batch_idx):
pass
def test_step(self, batch, batch_idx):
pass
def configure_optimizers(self):
pass
reference