파이토치 라이트닝(Pytorch lightning)은 기존의 파이토치에 대한 high level의 인터페이스를 제공하는 라이브러리이다.
파이토치 라이트닝은 GPU, TPU사용과 16-bit precision, 분산 학습과 학습/추론, 데이터로드 등의 부분을 한번에 모듈화할 수 있는 라이브러리이다.
pytorch의 nn.Module의 상위 클래스인 LightningModule을 구현하여 trainer와 모델이 상호작용할 수 있다. LightningModule에서는 오버라이딩하여 편리하게 사용할 수 있는 메서드들이 많이 있다.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
def __init__(self):
super(MyModel, self).__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64,3))
self.decoder = nn.Sequential(nn.Linear(3,64), nn.ReLU(), nn.Linear(64, 28 * 28))
# forward를 구현하면 다른 메서드에서 self(입력)을 사용한 결과를 얻을 수 있다.
# in Lightning, forward defines the prediction/inference actions
def forward(self, x):
embedding = self.encoder(x)
return embedding
# train loop, forward와 독립적으로 실행됨
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
# tensorboardlogger에 기록
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
# forward 호출 self(입력)
y_hat = self(x)
val_loss = F.cross_entropy(y_hat, y)
return val_loss
def validation_epoch_end(self, outputs):
# validation 에폭의 마지막에 호출됨
# outputs은 각 batch마다 validation_step에 return의 배열
# outputs = [{'loss' : batch_0_loss}, {'loss': batch_1_loss}...]
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'val_loss' : avg_loss}
return {'avg_val_loss' : avg_loss, 'log':tensorboard_logs}
def test_step(self, batch, batch_idx):
# shared_eval_step을 사용해 validation_step, test_step 에서 공동으로 사용가능
loss, acc = self._shared_eval_Step(batch, batch_idx)
return loss
# train, val loop가 동일한 경우 shared_step으로 재사용가능
def shared_step(self, batch):
x, y= batch
...
return F.cross_entropy(y_hat, y)
def _shared_eval_step(self, batch, batch_idx):
x,y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)
return loss, acc
# multiple predict dataloaders일때 dataloader_idx인수사용
# train_step, test_step등에서도 사용가능
def predict_step(self, batch, batch_idx, dataloader_idx):
x,y= batch
x = x.view(x.size(0), -1)
return self.encoder(x)
def configure_optimizers(self):
# optimizer/scheduler 설정
optimizer = torch.optim.Adam()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
return [optimizer], [scheduler]
def configure_optimizers(self):
optimizer1 = Adam(...)
optimizer2 = SGD(...)
scheduler1 = ReduceLROnPlateau(optimizer1, ...)
scheduler2 = LambdaLR(optimizer2, ...)
return (
{
"optimizer": optimizer1,
"lr_scheduler": {
"scheduler": scheduler1,
"monitor": "metric_to_track",
},
},
{"optimizer": optimizer2, "lr_scheduler": scheduler2},
)
# most cases. no learning rate scheduler
def configure_optimizers(self):
return Adam(self.parameters(), lr=1e-3)
# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
return gen_opt, dis_opt
# example with learning rate schedulers
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
dis_sch = CosineAnnealing(dis_opt, T_max=10)
return [gen_opt, dis_opt], [dis_sch]
model = MyLightningModule(...)
model.freeze()
model = MyLightningModule(...)
model.unfreeze()
self.log('train_loss', loss)
parameter
float
, Tensor
, Metric
or dictionary
)training_step : on_step = T, on_epoch = F, prog_bar = F, logger = T
training_step_end : on_step = T, on_epoch = F, prog_bar = F, logger = T
training_epoch_end : on_step = F, on_epoch = T, prog_bar = F, logger = T
validation_step(apply test loop) : on_step = F, on_epoch = T, prog_bar = F, logger = T
validation_step_end(apply test loop) : on_step = F, on_epoch = T, prog_bar = F, logger = T
validation_epoch_end(apply test loop) : on_step = F, on_epoch = T, prog_bar = F, logger = T
class ArgsModel(HyperparametersMixin):
def __init__(self, arg1, arg2, arg3):
super().__init__()
self.save_hyperparameters('arg1', 'arg3')
def forward(self, *args, **kwargs):
...
model = ArgsModel(1, 'abc', 3.14)
model.hparams
"arg1" : 1
"arg3" : 3.14
class ArgsModel(HyperparametersMixin):
def __init__(self, arg1, arg2, arg3):
super().__init__()
self.save_hyperparameters()
def forward(self, *args, **kwargs):
...
model = ArgsModel(1, 'abc', 3.14)
model.hparams
"arg1" : 1
"arg2" : abc
"arg3" : 3.14
class ArgsModel(HyperparametersMixin):
def __init__(self, params):
super().__init__()
self.save_hyperparameters(params)
def forward(self, *args, **kwargs):
...
model = ArgsModel(Namespace(p1=1, p2='abc', p3=3.14))
model.hparams
"p1" : 1
"p2" : abc
"p3" : 3.14
class ArgsModel(HyperparametersMixin):
def __init__(self, arg1, arg2, arg3):
super().__init__()
self.save_hyperparameters(ignore='arg2')
def forward(self, *args, **kwargs):
...
model = ArgsModel(1, 'abc', 3.14)
model.hparams
"arg1" : 1
"arg3" : 3.14
pl.seed_everything(seed)
model = MyModel()
trainer = Trainer()
trainer.fit(model)
trainer.test()
pretrained_model = LightningModule.load_from_checkpoint(path)
pretrained_model.freeze()
# finetuning
def forward(self, x):
features = =pretrained_model(x)
classed = classifier(featrues)
# prediction
out = pretrained_model(x)
gpus
, num_does
, etc...)layer_num
, num_layers
, learning_rate
, etc...)data_path
, cluster_email
, etc...LightningModule
에서 add_model_specific_args
를 구현
class Base(pl.LightningModule):
def __init__(self, hparams, **kwargs):
super(Base, self).__init__()
self.hparams = hparams
@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--batch-size', type=int, default=10, help ='batch size (default:10))
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--data_path', type=str, default=None, help='data path')
return parser
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--checkpoint_path', type=str, help='checkpoint path')
parser = Base.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
model = MyModel(args)
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
trainer.fit(model)
LightningDataModule은 pytorch의 데이터처리의 5가지 step을 캡슐화한 모듈이다.
Dataset
DataLoader
여러 메소드 -train_dataloader
, val_dataloader
, test_dataloader
, setup
을 LightningDataModule
에 구현하여 사용할 수 있다.
# setup은 train, valid, test, and predict로 데이터를 나눈다.
def setup(self, stage):
self.train = Dataset(...)
self.valid = Dataset(...)
self.test = Dataset(...)
def train_dataloader(self):
train = DataLoader(...)
return train
def val_dataloader(self):
val = DataLoader(...)
return val
def test_dataloader(self):
test = DataLoader(...)
return test
LightningDataModule API
DataModule에서 5가지 메서드를 정의할 수 있다.
prepare_data
: process에서 한번 호출(하나의 스크립트를 실행할때 한번만 호출되기 때문에 데이터 다운로드나 데이터 변형과 같은 작업 수행)setup
: gpu마다 호출하여 수행train_dataloader
val_dataloader
test_dataloader
선택적으로 predict_dataloader
를 정의할 수 있다.
class MNISTDataModule(pl.LightningDataModule):
def prepare_data(self):
# download
self.trainning_data=MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
class MNISTDataModule(pl.LightningDataModule):
def setup(self):
# Assign Train/val split(s) for use in Dataloaders
if stage in (None, "fit"):
self.mnist_train, self.mnist_val=random_split(self.training_data, [55000, 5000])
if stage in (None, "test"):
self.mnist_test = MNSIT(self.data_dir, train=False, download=True, transformm=self.transform)
class MNISTDataModule(pl.LightningDataModule):
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=64)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=64)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=64)
dm = MNISTDataModule()
model = MyModel()
trainer = pl.trainer(...)
trainer.fit(model, dm)
trainer.test(datamodule=dm)
checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback])
dirpath : model 파일을 저장하기 위한 디렉토리 path dirpath = "my/logs/"
filename : checkpoint의 filename. formatting을 사용할수 있음 filename = {epoch}-{val_loss:2f}-}
, filename = 'model_chp/{epoch}-{val_loss:.2f}
로 logs디렉토리 아래에 추가 디렉토리를 생성하여 저장할 수 있다.
monitor : 모니터링할 key값, default는 None
으로 마지막 에폭에만 저장된다. monitor='val_loss'
save_last : True
시 마지막 에폭에 대한last.ckpt
파일을 저장한다. default = None
save_top_k : save_top_k
==
k(int)
값을 주면, 모니터링하는 값이 가장 좋은 k개의 모델을 저장한다. k
가 0이면 모델을 저장하지 않고, k
가 -1이면 모든 모델(에폭별)이 저장된다.
mode : save_top_k
!=
0
이면 현재 저장된 모델 파일을 덮어씌우기 위한 선택은 모니터링 되는 값의 최대값 또는 최소값이 기준이 된다. 정확도의 경우 mode
==
max
여야 하고, loss의 경우 mode
==
min
이여야 한다.
s
Trainer의 callbacks인자로 지정하여 사용할 수 있다.
early_stopping = EarlyStopping(
monitor='val_acc',
patience=args.patience,
verbose=True,
mode='max'
)
trainer = pl.Trainer(callbacks=[checkpoint_callback, early_stopping]
Tip
추천 num_workers
num_workers = 4 * num_GPU
16bit 제공
Trainer(precision = 16)
Tip
profile 제공
Trainer(profile=True)
원하는 profile을 인수로 제공할 수 있음
profiler = AdvancedProfiler()
trainer = Trainer(profiler=profiler)
ckpt
파일을 자동으로 저장한다.class MyModel(LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross.entropy(y_hat, y)
self.log('val_loss', loss)
checkpoint_callback = ModelCheckpoint(monitor='val_loss')
trainer = Trainer(callbacks = [checkpoint_callback])
monitoring할 key값을 제공하여 저장이 가능함
trainer = Trainer(checkpoint_callback=False)
trainer.fit(model)
trainer.save_checkpoint("sample.ckpt")
new_model = Mymodel.load_from_checkpoint(checkpoint_path="sample.ckpt")
model = Mymodel()
trainer = Trainer()
trainer.fit(model, ckpt_path="./ckpt/my_checkpoint.ckpt")
load_model = MyModel.load_from_checkpoint("ckpt file")
load_model.eval()
load_model.freeze()
y_hat = load_model(x)