문제
torch.distributed.barrier()
는 distributed training (multi-gpu training) 환경에서 multi-process로 학습을 수행할 때, 각 process (rank)들마다 진행 속도가 다를 수 있다.
- 모든 process들이 barrier()에 도달할 때까지 wait()을 걸어줌으로써 sync를 맞춰주는 역할을 한다.
- process가 도달을 하지 않거나, sync가 맞지 않으면 무한 대기에 빠진다.
- 모델 학습할 때 프로세스가 대강 아래와 같은데, epoch loop의
torch.distributed.barrier()
에서 stall되는 문제가 있었다.
pseudo code
for loop:
for loop:
...
output = model(input)
...
loss.backward
...
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
...
validation()
...
torch.save()
...
torch.distributed.barrier()
Validation code (rank=0만)
def validation(self):
acc1 = 0.0
acc5 = 0.0
losses = []
with torch.no_grad():
for _, data in enumerate(tqdm(self.val_loader)):
input = data[0].cuda()
gt = data[1].cuda()
outputs = self.encoder(input)
loss = self.CELoss(outputs, gt)
losses.append(loss)
acc1_, acc5_ = self.accuracy(outputs, gt, topk=(1, 5))
acc1 += acc1_
acc5 += acc5_
acc1 = acc1 / len(self.val_loader)
acc5 = acc5 / len(self.val_loader)
loss = sum(losses) / len(losses)
return acc1, acc5, loss
해결
- 다음 링크에 따르면,
torch.distributed.get_rank() == 0
에 조건문을 걸고 모델을 실행시키더라도, 모델이 DistributedDataParallel()
로 wrap되어있을 때 특정 상황(?)에서 forward pass때 sync를 대기하는 현상이 있다고 한다.
- 아래와 같이 validation() 함수에서
self.encoder.module(input)
으로 변경하여 해결하였음.
def validation(self):
acc1 = 0.0
acc5 = 0.0
losses = []
with torch.no_grad():
for _, data in enumerate(tqdm(self.val_loader)):
input = data[0].cuda()
gt = data[1].cuda()
outputs = self.encoder.module(input)
loss = self.CELoss(outputs, gt)
losses.append(loss)
acc1_, acc5_ = self.accuracy(outputs, gt, topk=(1, 5))
acc1 += acc1_
acc5 += acc5_
acc1 = acc1 / len(self.val_loader)
acc5 = acc5 / len(self.val_loader)
loss = sum(losses) / len(losses)
return acc1, acc5, loss