Pytorch Lightning 사용기

yslee·2022년 2월 2일
1

Pytorch Lightning

목록 보기
1/2
post-thumbnail

pytorch lightning(PL) Introduction

최근 두 달 pytorch lightning(as PL)을 찍먹하며 쓸수록 괜찮은 프레임워크라는 생각이 든다.

괜찮은 프레임워크라는 이유는 아래와 같다.
0. PL 도입 시 프로젝트 코드 구조를 좀 더 확실하게 가져갈 수 있음 (개인적으로 이게 가장 좋다고 생각)
1. 학습에 많은 부분이 자동화되어 있어 실수를 줄일 수 있음
2. data, model, training 각각의 파트가 나누어져 있어 코드 관리가 용의함
3. mlflow, wandb, tensorboard 등 다양한 logger와 통합이 유리
4. 수동으로 사용하면 pytorch 사용하는 방식을 그대로 사용 가능



먼저 PL은 pytorch, tensorflow(as TF), onnx와 같은 또 다른 프레임워크가 아니다.
PL은 pytorch의 문법을 그대로 사용하면서 pytorch 작업 시 사용되는 보일러플레이트를 제거할 수 있는 Keras와 같은 고수준 API를 제공하는 pytorch의 상위 프레임워크이다. (개인적인 생각)

PL 공식 문서의 PL 철학은 다음과 같다.

Lightning Philosophy
Organizing your code with Lightning makes your code:

  • Keep all the flexibility (this is all pure PyTorch), but removes a ton of boilerplate
  • More readable by decoupling the research code from the engineering
  • Easier to reproduce
  • Less error-prone by automating most of the training loop and tricky engineering
  • Scalable to any hardware without changing your model

Lightning is built for:

  • Researcher who want to focus on research without worrying about the engineering aspects of it
  • ML Engineers who want to built reproducible pipelines
  • Data Scientists who want to try out different models for their tasks and build-in ML techniques
  • Educators who seek to study and teach Deep Learning with PyTorch
    -The team makes sure that all the latest techniques are already integrated and well maintained.
  • pytorch를 그대로 사용할 수 있음
  • 코드가 분리되어 있어 가독성이 좋음
  • 재생산성이 좋고
  • 자동화되어 사람의 실수를 줄일 수 있음
  • 최소한의 코드변경으로 가속기를 변경할 수 있음

철학이 그러하다면 어떤 방식으로 자동화를 진행하는지 알아보자

사용된 예제는 공식문서를 참고한다.

PL LightningModule 정의

먼저 학습 모델을 만드는 과정이다.

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28),
        )

nn.Module 대신 LightningModule을 상속받아 모델을 생성한다.
LightningModule은 nn.Module을 상속받기 때문에 nn.Module의 기능을 모두 가지고 있다.
모델링 작업은 torch와 유사하게 생성자에서 네트워크에 필요한 모듈을 생성한다.
필요하다면 생성자의 파라미터를 추가해 관리하는 것이 가능하다.

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

LightningModule은 nn.Module을 상속받기 때문에 forward를 오버라이딩 하지 않으면 torch.nn.Module에서 NotImplementedError가 일어나게 된다.

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss
        '''
        return loss with encoder hidden vector
        return {"loss": loss, "hiddens":z,} 
        '''
		

여기서부터 기존에 보지 못한 녀석이 나오게 된다.
training_step은 학습 로직을 가지고 있다.
추론 부터 loss 계산을 마치고 리턴값으로 backward를 진행할 loss를 반환해야 한다.
반환되는 값의 경우 loss 이외 다른 값을 같이 반환하기 위해서는 python 딕셔너리를 사용해 "loss" 키값으로 loss를 같이 반환해야 학습이 진행된다.
training_step은 입력으로 배치 데이터, 현재 진행 중인 배치 인덱스가 들어온다. 이외에도 옵티마이져 인덱스를 받을 수 있다.

training_step의 자세한 내용은 공식문서를 참고

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

configure_optimizers는 이름처럼 학습에 사용되는 최적화 함수를 설정한다.
configure_optimizers의 내용은 스케줄링, 최적화 함수 개수에 따라 다양한 출력 방식이 있으므로 공식문서를 참고

LightningModule + PL.Trainer

dataset = MNIST(
    os.getcwd(),
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(dataset)

학습에 사용할 데이터 설정 pytorch에서 쓰는 방식을 그대로 가지고 왔지만 PL의 경우 LightningDataModule이라는 데이터 셋의 묶음이 있다. 다른 글에서 이 내용을 다루고자 한다.
여기서는 pytorch의 dataloader를 그대로 사용할 수 있다는 것을 보여주고자 한다.

# model 생성
autoencoder = LitAutoEncoder()

앞에서 작성한 AutoEncoder 모델을 생성

# callback 설정 
callbacks = [
    TQDMProgressBar(refresh_rate=5),
    LearningRateMonitor(logging_interval='epoch'),
]

학습에 사용할 callback 설정 Keras를 사용해봤다면 익숙한 방식이다.
개인적으로 선호하지 않지만, callback으로 처리하면 편한 부분이 있기 때문에 이렇게 사용하자...

# trainer 생성
trainer = pl.Trainer(
        logger=wandb_logger, # logger 설정
        gpus=1,				 # 학습에 사용할 gpu 가속기 설정
        max_epochs=epochs,	 # 학습 epochs 설정
        callbacks=callbacks, # 학습중 사용될 callback list 설정
    )

# training
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

Trainer는 PL의 큰 특징이며 코드관리를 편하게 만들어 주는 포인트인 거 같다.

머신러닝 프로젝트를 진행할 때 개인적인 능력과 경험이 부족해 모델과 학습 코드를 분리하기 쉽지 않은 경우가 많았다.

나름 레이어를 잡아 분리한다고 해도 어디선가 종속성이 생겨 결국 학습 설정(ex: multi gpu, AMP, 학습 루핑 변경, 로거 변경, 로깅방식 변경 등등...)이 변경되면 학습 루핑을 담당하는 전체 코드를 훑으면서 코드를 변경했다.

실수가 없으면 다행이지만 변경사항이 많아 어디선가 실수를 하게 되면 짜증이 올라오기 시작한다. (보통 이런 경우 에러를 뿌리지 않기 때문에 뒤늦게 확인하는 경우가 많음...)

특히 금요일 저녁에 학습을 돌려놓고 토요일 저녁에 hyperparams config에 오류(주로 코드 변경하다 까먹어 설정 변경을 하지 않은 경우...)가 있다는 걸 확인하면 주말이 고스란히 날아가는 경우가 생기게 된다.

그런점에서 PL의 Trainer는 아주 좋다고 생각했다. 학습설정을 코드변경이 거의 없이 설정의 변경만으로 바꿀 수 있기 때문에 실험을 관리하는 입장에서 신경 써야 할 부분이 줄어들어 실수를 줄일 수 있다고 생각한다.

Trainer는 다양한 옵션이 존재하고 다양한 학습 방식을 결정할 수 있다.

  • logger (tb, mlflow, wandb, neptune, comet, custom_logger...)
  • 가속기 설정 (tpu, gpu, ipu ...)
  • Gradient Accumulation
  • Distributed Training
  • AMP
  • Checkpoint
  • ...

Significant & Limitation

Pytorch Lightning에 대한 간략한 소개를 진행했다. PL을 사용하면서 느낀 개인적인 장점은

  • 머신러닝 코드의 추상화와 구조화를 통해 OOP에 친숙한 개발자들에게 가속성이 좋은 코드를 만들 수 있도록 도와줌
  • 최소한의 코드변경으로 다양한 학습 방식을 사용할 수 있음
  • 2번의 연장선이지만 코드변경이 최소화되면서 실수를 줄일 수 있음 (막을 수 있다는 게 아니다!)

지금까지 발견한 단점은

  • notebook에 익숙하다면 조금 난해할 수도 있음
  • keras와 달리 pytorch를 모르는 사람이 사용하기엔 힘들 수 있음
  • pytorch에서 한 번 더 래핑이 되기 때문에 디버깅이 조금 귀찮을 수 있음
  • PL 코드 내부를 한번 보지 않으면 내가짠 코드가 어떤 식으로 동작 파악하는데 어려움이 있을 수 있음
  • 학습 시 동작하는 다양한 hook이 있는데 종류가 많다. 다음 글에선 이 부분을 다뤄보고자 한다.
  • 옵션으로 config dict를 넣어 사용하는 부분이 보이는데 처음 사용하면 이 부분이 난해 할수 있음

PL 도입을 추천하는 경우

  • pytorch를 사용해 모델을 짜는데 코드를 구조적으로 관리해보고자 하는 사람들
  • 파라미터가 아닌 학습 옵션들을 빈번하게 변경해야 하는 경우

PL 도입을 추천하지 않는 경우

  • pytorch를 사용해보지 않았을 때
  • 모델 구조가 지나치게 복잡한 경우
  • 학습 코드의 추상화를 원하지 않는 경우

Ending

연구실에 있을 땐 GAN을 사용한 이미지생성을 주로 진행했다.
당시에도 코드 관리의 필요성을 느껴 나름 추상클래스와 추상메서드를 사용해 공용 상속받아 사용할 수 있는 학습 템플릿 모듈을 만들어 진행했다.
하지만 학습 루프에서 종속성을 깔끔하게 분리하지 못해 이런저런 난항이 많았다.
당시 PL 도입을 생각했지만, 시간이 부족하다는 핑계와 GAN 구조상 PL에서 training_step이 깔끔하게 나누어지지 않다는 이유로 도입을 미루고 있었다.
시간이 생겨 밤마다 이미지 도메인에서 분류모델을 구현하며 PL을 도입해 사용해 만족해하고 있지만, 아직도 GAN에서 PL도입이 좋은 선택일까 고민하고 있다. 이 부분은 조만간 코드를 작성하며 생각해 보려 한다.

profile
지식보다 지혜를

0개의 댓글