오늘은 pytorch DistributedDataParallel(DDP)를 이용해서 multi-gpu 학습을 하는 방법을 알아봐요.
먼저 필요한 모듈을 불러와요.
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
필요한 함수들을 정의해요.
def setup():
dist.init_process_group("nccl")
def cleanup():
dist.barrier()
dist.destroy_process_group()
이제 기존 모델 코드에 DDP를 감싸면 돼요.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[device])
그리고 기존 데이터셋에 DistributedSampler를 적용해줘요.
이 때, shuffle=False로 해줘야 된답니다.
train_data = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=32,
shuffle=False, # 이 부분은 False로 해줘야해요!
sampler=DistributedSampler(train_dataset),
)
epoch를 시작할 때 train_dataloader.sampler.set_epoch(epoch)를 적용해요.
for epoch in epochs:
train_dataloader.sampler.set_epoch(epoch)
마지막으로, setup()과 cleanup() 함수를 적용해주면 돼요.
def main():
setup() # 시작!
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size=32)
trainer = Trainer(model, train_data, optimizer, device, save_every)
trainer.train(total_epochs)
cleanup() # 끝!
이제 gpu수에 맞춰서 nproc_per_node를 설정하고 파일을 torchrun으로 실행시키면 돼요.
torchrun --standalone --nproc_per_node=4 train.py
참고
https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py