class Trainer:  # was `def Trainer:` — a class must use the `class` keyword
    """Distributed-training helper.

    Collects the per-node GPU count and the rendezvous parameters
    (backend, master address/port) needed by
    ``torch.distributed.init_process_group``.
    """

    def __init__(self, args):
        # One process per GPU is the usual DDP layout, so the local GPU
        # count doubles as the number of workers on this node.
        self.ngpus_per_node = torch.cuda.device_count()
        # Backward-compatible alias: the original code stored this under
        # the misspelled name `ngpus_per_nodes`; keep both so existing
        # readers of either attribute still work.
        self.ngpus_per_nodes = self.ngpus_per_node
        self.node_rank = args.rank          # index of this node in the cluster
        self.dist_backend = args.dist_backend  # e.g. 'nccl' or 'gloo'
        # Rendezvous endpoint; fall back to single-machine defaults when
        # the launcher did not export MASTER_ADDR / MASTER_PORT.
        self.master_addr = os.environ.get("MASTER_ADDR", "localhost")
        self.master_port = os.environ.get("MASTER_PORT", "8888")
        # `args.dist_url` is expected to be a scheme prefix such as "tcp://".
        self.dist_url = f"{args.dist_url}{self.master_addr}:{self.master_port}"
def init_distributed(self):
if self.distributed:
if torch.cuda.is_available():
self.gpu = self.local_rank % self.ngpus_per_node
self.device = torch.device(self.gpu)
if self.distributed:
self.local_rank = self.gpu
self.rank = self.gpu
print(f'rank {self.rank} is running...')
dist.init_process_group(backend=self.dist_backend, init_method=self.dist_url,
world_size=self.world_size, rank=self.rank)
dist.barrier()
self.setup_for_distributed(self.is_main_process())
else:
self.device = torch.device('cpu')
# Wrap the model for multi-GPU training, then run one pass over the loader.
if self.distributed:
    # DistributedDataParallel synchronises gradients across ranks; each
    # process drives the replica on its own GPU (`self.local_rank`).
    self.model = torch.nn.parallel.DistributedDataParallel(
        self.model,
        device_ids=[self.local_rank],
        output_device=self.local_rank,
        find_unused_parameters=True,
        static_graph=True,
    )
for step, batch in enumerate(self.train_dataloader):
    # Move every tensor in the batch onto this rank's device, in place.
    # assumes `batch` is a dict of tensors — TODO confirm against the dataset
    for key in batch:
        batch[key] = batch[key].to(self.local_rank)
    y = self.model(batch)
단일 GPU의 VRAM을 초과하는 batch_size가 필요한 경우, 다수의 GPU에 작업을 나누어 학습할 수 있는 방법이다.
즉, 여러 GPU를 동시에 사용하여 메모리 용량과 학습 속도를 확보하는 방식이라고 이해하면 된다.