Data Parallel (multi-thread):
모든 GPU에 모델 복사
→ 각 GPU에서 gradient 계산
→ Gradient를 하나의 GPU에 모아 업데이트
Distributed Data Parallel (multi-process):
모든 GPU에 모델 복사
→ 각 GPU에서 gradient 계산
→ Averaged gradient를 통해 모든 GPU에서 각각 업데이트
Full-Sharded Data Parallel (multi-process):
모든 GPU에 모델 layer-sharding
→ 다른 GPU의 shard 정보를 임시 복사해 all-gather 연산으로 순전파
→ 다른 GPU의 shard 정보를 임시 복사해 all-gather 연산으로 역전파
→ 각 GPU에서 gradient 계산
→ Layer가 속해있던 원래의 GPU로 gradient를 전달해 reduce-scatter 연산
→ 각 GPU의 shard에 대한 gradient를 통해 모든 GPU에서 각각 업데이트