이번 튜토리얼에서는 single gpu training code로 부터 multi gpu training을 위해 어떤 부분이 바뀌고 추가되는지에 대한 내용을 다룬다. 코드는 다음 github repo에서 확인할 수 있다.
- Note
만약 모델이 BatchNorm layer를 가지고 있다면, 모델의 모든 BatchNorm layer를 SyncBatchNorm으로 바꿔줘야 한다. 이를 위한 helper function으로 pytorch는 torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 함수를 제공한다.
아래는 DDP를 위해 추가로 import한 모듈들이다.
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
def ddp_setup(rank: int, world_size: int):
"""
Args:
rank: Unique identifier of each process
world_size: Total number of processes
"""
os.environ["MASTER_ADDR"] = "localhost" # machine의 IP 주소. single machine이라 localhost 입력
os.environ["MASTER_PORT"] = "12355" # 임의의 free port number
init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
MASTER_ADDR은 rank0 process가 실행되는 곳의 주소이기도 하다.
self.model = DDP(model, device_ids=[gpu_id])
model.to(device)를 사용하던 것 대신, DDP를 사용.
train_data = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=32,
shuffle=False, # shuffle 사용 x
sampler=DistributedSampler(train_dataset),
)
def _run_epoch(self, epoch):
b_sz = len(next(iter(self.train_data))[0])
self.train_data.sampler.set_epoch(epoch) # set_epoch 호출
for source, targets in self.train_data:
...
self._run_batch(source, targets)
ckp = self.model.module.state_dict() # model이 DDP 객채로 wrap되어 module로 호출해야함.
...
...
if self.gpu_id == 0 and epoch % self.save_every == 0: # 하나의 process를 사용하도록 조건 추가
self._save_checkpoint(epoch)
def main(rank, world_size, total_epochs, save_every):
ddp_setup(rank, world_size)
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size=32)
trainer = Trainer(model, train_data, optimizer, rank, save_every)
trainer.train(total_epochs)
destroy_process_group() # 모든 process 종료
if __name__ == "__main__":
import sys
total_epochs = int(sys.argv[1])
save_every = int(sys.argv[2])
world_size = torch.cuda.device_count()
mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)