Pytorch Lightning 시작하기

Juppi·2022년 12월 23일
0
post-custom-banner

📌 pytorch에 대한 high-level 인터페이스를 제공하는 오픈 소스 라이브러리

GPU나 TPU, 16bit precision, 분산학습 등 더욱 복잡한 조건에서 실험하게 될 경우 코드가 복잡해진다.

pytorch lightning의 목적은 코드의 추상화를 통해 좀 더 정돈되고 간결화된 코드를 작성하는 데에 있다.

pytorch lightning은 Lightning Model class 내에 DataLoader, Model, optimizer, Training roof 등을 한번에 구현하도록 되어 있다. 클래스 내부에 있는 함수명은 똑같이 써야하고 그 목적에 맞게 코딩하면 된다

pytorch lightning은 크게 Trainer와 Lightning Module로 나눌 수 있다.

Lightning Module은 모델 내부의 구조를 설계하는 research & science 클래스라고 생각할 수 있다. 모델의 구조나 데이터 전처리, 손실함수 등의 설정을 통해 모델을 초기화한다. 실제로 코드에서는 pl.LightningModule 클래스를 상속받아 새로운 LightningModule 클래스를 생성한다. 기존 PyTorch의 nn.Module과 같은 방식이라고 보면된다.

Trainer는 모델의 학습을 담당하는 클래스라고 볼 수 있다. 모델의 학습 epoch이나 batch 등의 상태뿐만 아니라, 모델을 저장해 로그를 생성하는 부분까지 담당한다. 실제로 코드에서는 pl.Trainer()라고 정의하면 끝이다.

pytorch lightning은 아래 명령어를 통해 간단히 설치할 수 있다

pip install pytorch-lightning

pytorch lightning을 통해 딥러닝 모델을 작성하는 순서는 아래와 같다.

  1. lightning module에서 상속된 새로운 Lightning Module 클래스를 작성한다.
  2. DataLoader를 통해 학습할 데이터를 준비한다.
  3. Trainer 객체를 만들고, 그 Trainer에 데이터와 Lightning Module 클래스를 주어 학습한다.

Lightning Module Class

lightning module은 trainer와 model이 상호작용할 수 있게 해주는 구현체이다.

기존 pytorch와 달리, DataLoader, Model, optimizer, Training loof 등을 Lightning Module Class 안에 한번에 구현하도록 되어있다.

모듈 정의를 위해 LightningModule 클래스를 상속받고 학습에 필요한 메서드를 구현해야한다.

Lightning Module은 아래 6가지로 구성된다

  • Computations (__init__)

    • 초기화 메서드
    • Lightning Module Calss에서 사용할 신경망을 정의한다
    • 신경망 레이어를 생성하려면 torch.nn.module 에서 불러오거나 확장해야한다
  • Train loop (training_step)

    • 학습 루프의 body 부분을 나타낸다.
    • nn.Moduleforward와 유사하지만, 단일 배치에서의 loss를 반환해야하며, 이는 train loop로 자동 반복된다
    • 이 메소드에서는 argument로 training dataloader가 제공하는 batch와 해당 batch의 index가 주어지고, 학습 loss를 계산하여 반환한다.
    • lightning에서는 batch의 tensor를 cpu/gpu tensor로 변경하는 코드를 따로 추가하지 않아도 trainer의 설정에 따라 자동으로 적절한 타입으로 변경해준다.
    • epoch level metric을 계산하고 log를 추적하려면 .log 메서드를 사용해야한다. 만약, training_step의 결과를 무엇인가 할 일이 있으면 training_epoch_end 메서드를 작성한다.
      def training_step(self, batch, batch_idx):
          x, y = batch
          y_hat = self.model(x)
          loss = F.cross_entropy(y_hat, y)
          # logs metrics for each training_step,
          # and the average across the epoch, to the progress bar and logger
          self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
          return loss
      
      def training_epoch_end(self, training_step_outputs):
          for pred in training_step_outputs:
              ...
  • Validation loop (validation_step)

    • loss 및 metric logging을 위한 validation_step및 test_step을 추가할 수 있다.
    • validation_step은 학습 중간에 모델의 성능을 체크하는 용도로 사용한다.
    • training_step과 마찬가지로 validation 데이터로더에서 제공하는 배치를 가지고 확인하고자 하는 통계량을 기록할 수 있습니다.
    • 만약에 각 validation_step의 결과로 무엇인가 할 일이 있으면 validation_epoch_end 메서드에 작성합니다.
      def validation_step(self, batch, batch_idx):
          x, y = batch
          y_hat = self.model(x)
          loss = F.cross_entropy(y_hat, y)
          self.log("val_loss", loss)
          pred = ...
          return pred
      
      def validation_epoch_end(self, validation_step_outputs):
          for pred in validation_step_outputs:
              ...
  • Test loop (test_step)
    • test_step은 앞의 두 함수와 비슷하게 test 데이터로더에서 제공하는 배치를 가지고 확인하고 싶은 통계량을 기록하는데 사용할 수 있다.
    • test loop 코드는 validation loop 코드와 거의 동일하다.
    • 호출할 때는 test_step()메서드 를 재정의해야한다.
      # call after training
      trainer = Trainer()
      trainer.fit(model)
      
      # automatically auto-loads the best weights from the previous run
      trainer.test(dataloaders=test_dataloader)
  • Prediction loop (predict_step)

  • Optimizers (configure_optimizers)

아래는 뼈대 코드이다

import pytorch_lightning as pl

class Classifier(pl.LightningModule):
	def __init__(self):
		super().__init__()
		self.model = nn.Sequential(
				...
		)

	def forward(self, x):
		pass

	def training_step(self, batch, batch_idx):
		pass

	def validation_step(self, batch, batch_idx):
		pass

	def test_step(self, batch, batch_idx):
		pass

	def configure_optimizers(self):
		pass

reference

https://wikidocs.net/156985

profile
잠자면서 돈버는 그날까지
post-custom-banner

0개의 댓글