pytorch: multi-gpu 학습하는 방법

djlee·2023년 3월 20일
0

오늘은 pytorch DistributedDataParallel(DDP)를 이용해서 multi-gpu 학습을 하는 방법을 알아봐요.
먼저 필요한 모듈을 불러와요.

from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

필요한 함수들을 정의해요.

def setup():
    dist.init_process_group("nccl")

def cleanup():
    dist.barrier()
    dist.destroy_process_group()

이제 기존 모델 코드에 DDP를 감싸면 돼요.

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[device])

그리고 기존 데이터셋에 DistributedSampler를 적용해줘요.
이 때, shuffle=False로 해줘야 된답니다.

train_data = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False, # 이 부분은 False로 해줘야해요!
    sampler=DistributedSampler(train_dataset),
)

epoch를 시작할 때 train_dataloader.sampler.set_epoch(epoch)를 적용해요.

for epoch in epochs:
  train_dataloader.sampler.set_epoch(epoch)

마지막으로, setup()과 cleanup() 함수를 적용해주면 돼요.

def main():
   setup() # 시작!
   dataset, model, optimizer = load_train_objs()
   train_data = prepare_dataloader(dataset, batch_size=32)
   trainer = Trainer(model, train_data, optimizer, device, save_every)
   trainer.train(total_epochs)
   cleanup() # 끝!

이제 gpu수에 맞춰서 nproc_per_node를 설정하고 파일을 torchrun으로 실행시키면 돼요.

torchrun --standalone --nproc_per_node=4 train.py

참고

https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py

0개의 댓글