pytorch lightning

hyunsooo·2022년 1월 3일
0

파이토치 라이트닝(Pytorch lightning)은 기존의 파이토치에 대한 high level의 인터페이스를 제공하는 라이브러리이다.

파이토치 라이트닝은 GPU, TPU사용과 16-bit precision, 분산 학습과 학습/추론, 데이터로드 등의 부분을 한번에 모듈화할 수 있는 라이브러리이다.

LightningModule

pytorch의 nn.Module의 상위 클래스인 LightningModule을 구현하여 trainer와 모델이 상호작용할 수 있다. LightningModule에서는 오버라이딩하여 편리하게 사용할 수 있는 메서드들이 많이 있다.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
	def __init__(self):
    	super(MyModel, self).__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64,3))
        self.decoder = nn.Sequential(nn.Linear(3,64), nn.ReLU(), nn.Linear(64, 28 * 28))
        
    # forward를 구현하면 다른 메서드에서 self(입력)을 사용한 결과를 얻을 수 있다.
    # in Lightning, forward defines the prediction/inference actions
    def forward(self, x):
    	embedding = self.encoder(x)
        return embedding
        
     # train loop, forward와 독립적으로 실행됨
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        
        # tensorboardlogger에 기록
        self.log("train_loss", loss)
        return loss
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        # forward 호출 self(입력)
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y)
        return val_loss
        
    def validation_epoch_end(self, outputs):
        # validation 에폭의 마지막에 호출됨
        # outputs은 각 batch마다 validation_step에 return의 배열
        # outputs = [{'loss' : batch_0_loss}, {'loss': batch_1_loss}...]
        
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss' : avg_loss}
        return {'avg_val_loss' : avg_loss, 'log':tensorboard_logs}
        
    def test_step(self, batch, batch_idx):
    	# shared_eval_step을 사용해 validation_step, test_step 에서 공동으로 사용가능
        loss, acc = self._shared_eval_Step(batch, batch_idx)
        return loss
        
    # train, val loop가 동일한 경우 shared_step으로 재사용가능
    def shared_step(self, batch):
        x, y= batch
        ...
        return F.cross_entropy(y_hat, y)
         
    def _shared_eval_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)
        return loss, acc
        
    # multiple predict dataloaders일때 dataloader_idx인수사용
    # train_step, test_step등에서도 사용가능
    def predict_step(self, batch, batch_idx, dataloader_idx):
        x,y= batch
        x = x.view(x.size(0), -1)
        return self.encoder(x)
        
    def configure_optimizers(self):
        # optimizer/scheduler 설정
        optimizer = torch.optim.Adam()
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return [optimizer], [scheduler]
  • configure_optimizers return
    • single optimizer
    • List or Tuple of optimizers
    • Two lists : 첫번째 list는 multiple optimizers, 두번째는 multiple LR schedulers
    • Dictionary : "optimizer", (optionally) "lr_schduler" key사용
      {"optimizer" : optimizer,
      "lr_schduler" : {
      "scheduler" : LRScheduler(optimizer, ...),
      "monitor" : "metric_to_track",
      "frequency" : metric update 수 지정}
      }
      }
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealing(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

freeze

model = MyLightningModule(...)
model.freeze()

unfreeze

model = MyLightningModule(...)
model.unfreeze()

log

self.log('train_loss', loss)

parameter

  • name : log의 key값
  • value : log의 value값 (float, Tensor, Metric or dictionary)
  • prog_bar : True시 progress bar 기록
  • on_step : True시 step 단계로 기록
  • on_epoch : True시 epoch 단계 metric을 기록
  • logger : True시 logger기록

training_step : on_step = T, on_epoch = F, prog_bar = F, logger = T

training_step_end : on_step = T, on_epoch = F, prog_bar = F, logger = T

training_epoch_end : on_step = F, on_epoch = T, prog_bar = F, logger = T

validation_step(apply test loop) : on_step = F, on_epoch = T, prog_bar = F, logger = T

validation_step_end(apply test loop) : on_step = F, on_epoch = T, prog_bar = F, logger = T

validation_epoch_end(apply test loop) : on_step = F, on_epoch = T, prog_bar = F, logger = T

save_hyperparameters

class ArgsModel(HyperparametersMixin):
    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        self.save_hyperparameters('arg1', 'arg3')
        
    def forward(self, *args, **kwargs):
        ...

model = ArgsModel(1, 'abc', 3.14)
model.hparams

"arg1" : 1
"arg3" : 3.14

class ArgsModel(HyperparametersMixin):
    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        self.save_hyperparameters()
        
    def forward(self, *args, **kwargs):
        ...

model = ArgsModel(1, 'abc', 3.14)
model.hparams

"arg1" : 1
"arg2" : abc
"arg3" : 3.14

class ArgsModel(HyperparametersMixin):
    def __init__(self, params):
        super().__init__()
        self.save_hyperparameters(params)
        
    def forward(self, *args, **kwargs):
        ...

model = ArgsModel(Namespace(p1=1, p2='abc', p3=3.14))
model.hparams

"p1" : 1
"p2" : abc
"p3" : 3.14

class ArgsModel(HyperparametersMixin):
    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        self.save_hyperparameters(ignore='arg2')
        
    def forward(self, *args, **kwargs):
        ...

model = ArgsModel(1, 'abc', 3.14)
model.hparams

"arg1" : 1
"arg3" : 3.14

Trainer

Basic use

pl.seed_everything(seed)

model = MyModel()
trainer = Trainer()
trainer.fit(model)
  • test
trainer.test()
  • prediction
pretrained_model = LightningModule.load_from_checkpoint(path)
pretrained_model.freeze()

# finetuning

def forward(self, x):
    features = =pretrained_model(x)
    classed = classifier(featrues)
    
# prediction

out = pretrained_model(x)

Argparser

  1. Trainer args ( gpus, num_does, etc...)
  2. Model specific arguments (layer_num, num_layers, learning_rate, etc...)
  3. Program arguments(data_path, cluster_email, etc...

LightningModule에서 add_model_specific_args를 구현


class Base(pl.LightningModule):
    def __init__(self, hparams, **kwargs):
        super(Base, self).__init__()
        self.hparams = hparams
        
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        
        parser.add_argument('--batch-size', type=int, default=10, help ='batch size (default:10))
        parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
        parser.add_argument('--data_path', type=str, default=None, help='data path')
        
        return parser
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--checkpoint_path', type=str, help='checkpoint path')

parser = Base.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)

args = parser.parse_args()
model = MyModel(args)
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)

trainer.fit(model)

LightningDataModule

LightningDataModule은 pytorch의 데이터처리의 5가지 step을 캡슐화한 모듈이다.

  1. Download/ tokenize/ process
  2. Clean and save to disk
  3. Load inside Dataset
  4. Apply transforms(rotate, tokenize, etc...)
  5. Wrap inside a DataLoader

여러 메소드 -train_dataloader, val_dataloader, test_dataloader, setupLightningDataModule에 구현하여 사용할 수 있다.

# setup은 train, valid, test, and predict로 데이터를 나눈다.
    def setup(self, stage):
        self.train = Dataset(...)
        self.valid = Dataset(...)
        self.test = Dataset(...)
        
    def train_dataloader(self):
        train = DataLoader(...)
        return train
        
    def val_dataloader(self):
        val = DataLoader(...)
        return val
        
    def test_dataloader(self):
        test = DataLoader(...)
        return test

LightningDataModule API

DataModule에서 5가지 메서드를 정의할 수 있다.

  • prepare_data : process에서 한번 호출(하나의 스크립트를 실행할때 한번만 호출되기 때문에 데이터 다운로드나 데이터 변형과 같은 작업 수행)
  • setup : gpu마다 호출하여 수행
  • train_dataloader
  • val_dataloader
  • test_dataloader

선택적으로 predict_dataloader를 정의할 수 있다.

prepare_data

  • download
  • tokenize
  • etc...
class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # download
        self.trainning_data=MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        

setup

  • count number of classes
  • build vocabulary
  • perform train/val/test splits
  • apply transforms(defined explicitly in your datamodule)
  • etc
class MNISTDataModule(pl.LightningDataModule):
    def setup(self):
        # Assign Train/val split(s) for use in Dataloaders
        if stage in (None, "fit"):
            self.mnist_train, self.mnist_val=random_split(self.training_data, [55000, 5000])
            
        if stage in (None, "test"):
            self.mnist_test = MNSIT(self.data_dir, train=False, download=True, transformm=self.transform)
        

loader

class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
        
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
        
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

Using DataModule


dm = MNISTDataModule()
model = MyModel()
trainer = pl.trainer(...)
trainer.fit(model, dm)
trainer.test(datamodule=dm)

Model Checkpointing

checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback])

parameter

  • dirpath : model 파일을 저장하기 위한 디렉토리 path dirpath = "my/logs/"

  • filename : checkpoint의 filename. formatting을 사용할수 있음 filename = {epoch}-{val_loss:2f}-}, filename = 'model_chp/{epoch}-{val_loss:.2f}로 logs디렉토리 아래에 추가 디렉토리를 생성하여 저장할 수 있다.

  • monitor : 모니터링할 key값, default는 None으로 마지막 에폭에만 저장된다. monitor='val_loss'

  • save_last : True시 마지막 에폭에 대한last.ckpt파일을 저장한다. default = None

  • save_top_k : save_top_k == k(int)값을 주면, 모니터링하는 값이 가장 좋은 k개의 모델을 저장한다. k가 0이면 모델을 저장하지 않고, k가 -1이면 모든 모델(에폭별)이 저장된다.

  • mode : save_top_k != 0이면 현재 저장된 모델 파일을 덮어씌우기 위한 선택은 모니터링 되는 값의 최대값 또는 최소값이 기준이 된다. 정확도의 경우 mode == max여야 하고, loss의 경우 mode == min이여야 한다.

  • s

Early stopping

Trainer의 callbacks인자로 지정하여 사용할 수 있다.

early_stopping = EarlyStopping(
								monitor='val_acc',
                                patience=args.patience,
                                verbose=True,
                                mode='max'
                                )

trainer = pl.Trainer(callbacks=[checkpoint_callback, early_stopping]

Tip
추천 num_workers
num_workers = 4 * num_GPU
16bit 제공
Trainer(precision = 16)

Tip
profile 제공
Trainer(profile=True)
원하는 profile을 인수로 제공할 수 있음
profiler = AdvancedProfiler()
trainer = Trainer(profiler=profiler)

가중치 저장 및 로드

trainer를 사용하면 마지막 훈련 epoch에 대한 ckpt파일을 자동으로 저장한다.

class MyModel(LightningModule):
	def validation_step(self, batch, batch_idx):
    	x, y = batch
        y_hat = self(x)
        
        loss = F.cross.entropy(y_hat, y)
        self.log('val_loss', loss)
        
checkpoint_callback = ModelCheckpoint(monitor='val_loss')
trainer = Trainer(callbacks = [checkpoint_callback])

monitoring할 key값을 제공하여 저장이 가능함

checkpoint 비활성화

trainer = Trainer(checkpoint_callback=False)

수동 저장

trainer.fit(model)
trainer.save_checkpoint("sample.ckpt")

new_model = Mymodel.load_from_checkpoint(checkpoint_path="sample.ckpt")

Restroing Training State


model = Mymodel()
trainer = Trainer()

trainer.fit(model, ckpt_path="./ckpt/my_checkpoint.ckpt")

predict


load_model = MyModel.load_from_checkpoint("ckpt file")

load_model.eval()
load_model.freeze()

y_hat = load_model(x)
profile
지식 공유

0개의 댓글