I just hit the same problem and debugged it: you need to call `torch.cuda.set_device(rank)` *before* `dist.init_process_group()`. Otherwise every process defaults to `cuda:0`, so the NCCL collectives (`all_reduce`, `all_gather`, etc.) contend for the same GPU and hang or crash.