dist.get_rank() == 0 후 dist.barrier()에서 stall되는 현상

Chanyong·2023년 3월 29일
1

TIL

목록 보기
7/7

문제

  • torch.distributed.barrier()는 distributed training (multi-gpu training) 환경에서 multi-process로 학습을 수행할 때, 각 process (rank)들마다 진행 속도가 다를 수 있다.
    • 모든 process들이 barrier()에 도달할 때까지 wait()을 걸어줌으로써 sync를 맞춰주는 역할을 한다.
    • process가 도달을 하지 않거나, sync가 맞지 않으면 무한 대기에 빠진다.
  • 모델 학습할 때 프로세스가 대강 아래와 같은데, epoch loop의 torch.distributed.barrier()에서 stall되는 문제가 있었다.

pseudo code

for loop:	# epoch loop
	for loop:	# iteration loop
    	...
    	output = model(input)
        ...
        loss.backward
        ...
        torch.distributed.barrier()	# multi-gpu sync (1)
        
	if torch.distributed.get_rank() == 0:
    	...
        validation()
        ...
        torch.save()
        ...
	torch.distributed.barrier()	# multi-gpu sync (2)

Validation code (rank=0만)

    def validation(self):
        acc1 = 0.0
        acc5 = 0.0
        losses = []
        with torch.no_grad():
            for _, data in enumerate(tqdm(self.val_loader)):
                input = data[0].cuda()
                gt = data[1].cuda()

                outputs = self.encoder(input)
                loss = self.CELoss(outputs, gt)

                losses.append(loss)
                acc1_, acc5_ = self.accuracy(outputs, gt, topk=(1, 5))
                acc1 += acc1_
                acc5 += acc5_

        acc1 = acc1 / len(self.val_loader)
        acc5 = acc5 / len(self.val_loader)
        loss = sum(losses) / len(losses)

        return acc1, acc5, loss

해결

    def validation(self):
        acc1 = 0.0
        acc5 = 0.0
        losses = []
        with torch.no_grad():
            for _, data in enumerate(tqdm(self.val_loader)):
                input = data[0].cuda()
                gt = data[1].cuda()

                outputs = self.encoder.module(input)	# self.encoder(input) -> self.encoder.module(input)
                loss = self.CELoss(outputs, gt)

                losses.append(loss)
                acc1_, acc5_ = self.accuracy(outputs, gt, topk=(1, 5))
                acc1 += acc1_
                acc5 += acc5_

        acc1 = acc1 / len(self.val_loader)
        acc5 = acc5 / len(self.val_loader)
        loss = sum(losses) / len(losses)

        return acc1, acc5, loss
profile
AI Accelerator 연구

0개의 댓글