DDP 분산 데이터 병렬 처리

김민기·2023년 6월 4일
0
def Trainer:
	def __init__(self, args):
      self.ngpus_per_nodes = torch.cuda.device_count()	
      self.node_rank = args.rank	
      self.dist_backend = args.dist_backend	
      self.master_addr = os.environ.get("MASTER_ADDR", "localhost")	
      self.master_port = os.environ.get("MASTER_PORT", "8888")	
      self.dist_url = f"{args.dist_url}{self.master_addr}:{self.master_port}"

def init_distributed(self):
        if self.distributed:
            if torch.cuda.is_available():
                self.gpu    = self.local_rank % self.ngpus_per_node
                self.device = torch.device(self.gpu)
                if self.distributed:
                    self.local_rank = self.gpu
                    self.rank =  self.gpu
                    print(f'rank {self.rank} is running...')
                    dist.init_process_group(backend=self.dist_backend, init_method=self.dist_url,
                                            world_size=self.world_size, rank=self.rank)
                    dist.barrier()
                    self.setup_for_distributed(self.is_main_process())
        else:
            self.device = torch.device('cpu')
            
if self.distributed:
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True, static_graph=True)

for b, batch in enumerate(self.train_dataloader):
                for key in batch:
                    batch[key] = batch[key].to(self.local_rank)
                y = self.model(batch)

다수의 GPU를 사용하여 VRAM을 초과하는 batch_size를 선택했을 경우 사용할 수 있는 방법이다.

그냥 GPU 많이 써서 학습 시킨다고 생각하면 될듯

profile
work0ut

0개의 댓글