distributed training에서 하나의 process failure는 전체 process failure로 이어진다. 높아진 failure 위험으로 인해 robust한 training script가 중요하다.
Pytorch는 elastic training과 fault tolerance 기능이 있는 torchrun을 제공한다. torchrun은 error가 발생하면 log를 남기고, 자동으로 저장된 snapshot 시점에서 부터 학습을 다시 시작한다.
snapshot에는 모델의 state dict 뿐만 아니라, snapshot 시점의 epoch, optimizer state 등의 state들이 저장된다.
torchrun을 사용하면 사용자는 distributed training의 사소한 부분을 신경쓰지 않아도 된다.
graceful restart를 위해 다음의 code가 필요하다.
def main():
load_snapshot(snapshot_path)
initialize()
train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_snapshot(snapshot_path)
이전 DDP training code와 달리 mp.spaw을 호출하지 않고, rank와 world_size를 지정하지 않는다. 그리고 학습 iter마다 snapshot을 저장한다. 만약 failure가 발생하면 torchrun은 모든 process를 종료하고 snapshot 지점부터 재시작한다.
elastic training 중, node가 추가되거나 제거되는 등의 mebership에 대한 변경이 발생할 때마다 torchrun은 모든 process를 중단하고 학습 가능한 모든 자원을 활용하여 학습을 재시작 한다.
def ddp_setup():
init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
따라서 init_process_group에 rank와 word_size를 전달해 줄 필요가 없고, 추가로 master 주소와 port 또한 지정할 필요가 없다. 현재 rank는 LOCAL_RANK 환경변수로 부터 얻을 수 있다.
self.gpu_id = int(os.environ["LOCAL_RANK"])
gpu_id도 LOCAL_RANK 사용.
일반적으로 모든 학습에 관련된 정보를 snapshot에 저장할 수 있고, 이는 interruption 이후 학습이 매끄럽게 재시작 될 수 있도록 만들어준다.
def _save_snapshot(self, epoch):
snapshot = {}
snapshot["MODEL_STATE"] = self.model.module.state_dict()
snapshot["EPOCHS_RUN"] = epoch
torch.save(snapshot, "snapshot.pt")
print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")
def _load_snapshot(self, snapshot_path):
snapshot = torch.load(snapshot_path)
self.model.load_state_dict(snapshot["MODEL_STATE"])
self.epochs_run = snapshot["EPOCHS_RUN"]
print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
snapshot은 python dictionary 형태로 저장할 수 있다.
interruption 이후 학습이 재시작 하기 전 snapshot의 정보를 불러와야한다.
class Trainer:
def __init__(self, snapshot_path, ...):
...
if os.path.exists(snapshot_path):
self._load_snapshot(snapshot_path)
...
init 함수에 snapshot 존재 유무를 파악하여 존재하면 snapshot을 불러올 수 있도록 하는 코드를 작성한다.
재시작을 이전에 멈춘 시점의 epoch에서 시작하도록 코드를 작성한다.
def train(self, max_epochs: int):
for epoch in range(self.epochs_run, max_epochs):
self._run_epoch(epoch)
snapshot에는 현재 epcoh 상태를 함께 저장하고, 재시작시 마지막으로 저장된 epoch 상태에서부터 학습이 다시 진행된다.
multi-processing을 사용하지 않는 것처럼 entrypoint function(__main__)을 작성하고 이를 호출하면 된다. torhcrun이 자동으로 multi-process를 spaw 한다.
if __name__ == "__main__":
import sys
total_epochs = int(sys.argv[1])
save_every = int(sys.argv[2])
main(save_every, total_epochs)
torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10
nproc_per_node는 사용할 gpu의 수(process의 수)를 나타낸다.