PyTorch에서 나온 리서처의 편의(?)를 돕기 위한 라이브러리인데요, 공식 홈페이지에서는 이렇게 소개합니다.
You do the research. Lightning will do everything else.
사용해보니 많이 공감되는 말이긴 하지만, 제가 느낀 장점은 코드가 깔끔해진다는 것이 가장 컸습니다.
PyTorch Lightning은 3가지 Module에 대한 사용법만 익히면 끝입니다.
순서대로 LightningModule
, LightningDataModule
, 그리고 Trainer
입니다.
일종의 PyTorch의 한단계 high-level 언어라고 생각하면 되고, 기존의 PyTorch code들을 가져다가 묶어주는 Module들이라고 생각하면 좋습니다.
각각의 Module의 역할은 아래와 같습니다.
LightningModule
- model, loss
LightningDataModule
- data
Trainer
- main loop (train, valid, test)
예시 코드를 살펴보면 이해가 될텐데요, 각 module들은 위에서 언급한 기능들을 구현하는 함수를 가지고 있습니다. 이 함수들만 적절히 구현해주면 모든 구현이 완료됩니다.
import pytorch_lightning as pl
import torch.nn as nn
from torchvision.models import resnet34
import torch.nn.functional as F
class MainModule(pl.LightningModule):
def __init__(self, opt):
super(MainModule, self).__init__()
self.opt=opt
self.model = resnet34()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
im, label = batch
pred = self.model(im)
loss = F.cross_entropy(pred, label)
self.log('train/loss', loss.item())
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.model.parameters(), self.opt.lr)
return [optimizer]
class MainDataModule(pl.LightningDataModule):
def setup(self, stage=None):
self.train_dataset = DatasetClass(self.hparams.data_dir)
self.valid_dataset = DatasetClass(self.hparams.data_dir)
def train_dataloader(self):
return DataLoader(
dataset=self.train_dataset,
batch_size=32
)
def valid_dataloader(self):
return DataLoader(
dataset=self.valid_dataset,
batch_size=32
)
module = MainModule()
datamodule = MainDataModule()
tqdm_cb = TQDMProgressBar(refresh_rate=10)
ckpt_cb = ModelCheckpoint(
dirpath='./saved',
filename="{epoch:02d}_",
save_last=True
)
tb_logger = TensorBoardLogger(
name='exp_name',
save_dir='./log'
)
trainer = pl.Trainer(
accelerator="gpu",
devices=[gpu_number],
max_epochs=100,
callbacks=[tqdm_cb, ckpt_cb],
logger=tb_logger
)
trainer.fit(module, datamodule=datamodule)
개인적으로 생각하는 장점들은 아래와 같습니다.
코드 정형화
main.py (train.py, test.py) 코드 짧아짐
제공되는 기능이 많음.
lightning 모듈, data 모듈까지 설명해주셔서 감사합니다.
도움 많이 되었습니다!