[DL] Pytorch Lightning 사용법

dj_·2023년 1월 16일
2

PyTorch Lightning

PyTorch에서 나온 리서처의 편의(?)를 돕기 위한 라이브러리인데요, 공식 홈페이지에서는 이렇게 소개합니다.

You do the research. Lightning will do everything else.

사용해보니 많이 공감되는 말이긴 하지만, 제가 느낀 장점은 코드가 깔끔해진다는 것이 가장 컸습니다.

Lightning 사용법

PyTorch Lightning은 3가지 Module에 대한 사용법만 익히면 끝입니다.
순서대로 LightningModule, LightningDataModule, 그리고 Trainer 입니다.

일종의 PyTorch의 한단계 high-level 언어라고 생각하면 되고, 기존의 PyTorch code들을 가져다가 묶어주는 Module들이라고 생각하면 좋습니다.

각각의 Module의 역할은 아래와 같습니다.

  1. LightningModule - model, loss

    • train, validation, test loop 내용 정의
    • optimizer, learning rate scheduler 등을 정의
    • model, loss 정의
  2. LightningDataModule - data

    • Dataset, DataLoader 정의 (train, valid, test 모두)
    • 데이터 다운로드 (필요할 경우)
  3. Trainer - main loop (train, valid, test)

    • 보통 main.py 같은 파일에서 call
    • gpu number, epoch 수, logger(tensorboard 같은) 설정
    • 실제 train,test,valid loop 수행

PyTorch Lightning Code

예시 코드를 살펴보면 이해가 될텐데요, 각 module들은 위에서 언급한 기능들을 구현하는 함수를 가지고 있습니다. 이 함수들만 적절히 구현해주면 모든 구현이 완료됩니다.

  1. LightningModule
  • train loop에서 수행할 내용 정의 -> training_step
    - validation_step, test_step 함수도 만들어서 정의할 수 있음!
  • optimizer, scheduler 설정 -> configure_optimizers
import pytorch_lightning as pl 
import torch.nn as nn 
from torchvision.models import resnet34
import torch.nn.functional as F
class MainModule(pl.LightningModule):
	def __init__(self, opt):
		super(MainModule, self).__init__()
		self.opt=opt 	
		self.model = resnet34()
		
	def forward(self, x):
		return self.model(x)
        
	def training_step(self, batch, batch_idx):
		im, label = batch 
		pred = self.model(im)
		loss = F.cross_entropy(pred, label)
		self.log('train/loss', loss.item())
		return loss 
	
	def configure_optimizers(self):
		optimizer = torch.optim.Adam(
			self.model.parameters(), self.opt.lr)
		return [optimizer]
  1. LightningDataModule
  • dataset 정의 - setup
  • dataloader 정의 - train_dataloader
class MainDataModule(pl.LightningDataModule):
	def setup(self, stage=None):
    	self.train_dataset = DatasetClass(self.hparams.data_dir)
        self.valid_dataset = DatasetClass(self.hparams.data_dir)
        
   	def train_dataloader(self):
    	return DataLoader(
        	dataset=self.train_dataset,
            batch_size=32
        )
    
    def valid_dataloader(self):
   		return DataLoader(
        	dataset=self.valid_dataset,
            batch_size=32
        )
  1. Trainer (main.py)
  • 각 module들, 필요한 callback들 만들어서 객체 생성
  • training -> fit 함수
  • test -> test 함수
	module = MainModule()
    datamodule = MainDataModule()
    tqdm_cb = TQDMProgressBar(refresh_rate=10)
    ckpt_cb = ModelCheckpoint(
    	dirpath='./saved',
        filename="{epoch:02d}_",
        save_last=True
    )
    tb_logger = TensorBoardLogger(
    	name='exp_name',
        save_dir='./log'
    )
    
    trainer = pl.Trainer(
    	accelerator="gpu",
        devices=[gpu_number],
        max_epochs=100,
        callbacks=[tqdm_cb, ckpt_cb],
        logger=tb_logger
    )
    trainer.fit(module, datamodule=datamodule)

장점

개인적으로 생각하는 장점들은 아래와 같습니다.

  1. 코드 정형화

    • 아무래도 각 모듈들로 역할들을 분배해놓아서, 코드 구조가 거의 다 똑같다.
  2. main.py (train.py, test.py) 코드 짧아짐

    • 보통 PyTorch 코드들은 실제 loop를 수행하는 train.py, test.py 파일들을 구현합니다.
    • 저의 경우 이 파일들이 실험을 돌릴수록 점점 더러워졌는데..
    • module들로 model, data 관련 내용을 분리하고, callback, logger 기능들을 활용하면서 너어어어어어무 깔끔해졌다고 생각합니다.
  3. 제공되는 기능이 많음.

    • 가장 체감이 컸던건 tqdm, logger를 callback으로 넣어서 직접 구현할 일이 없다는 것입니다.
    • 이외에도 early stopping, 16-bit precision 등 기본적으로 제공해주는 기능들이 많아서 편합니다.

0개의 댓글