1. model loading
model parameter를 GPU로 보낸다.
현재 memory : model
2. forward pass
input을 model에 보내고, 중간(intermediate) output(==activation)을 적재한다. gradient checkpointing 사용 시, 모든 activation을 저장하지 않고 필요한 activation만 저장할 수 있다.
현재 memory : model + actications
3. backward pass
network의 끝에서 시작부분까지 gradient를 연산한다. 이후 forward activation은 제거한다. activation을 제거하면 memory에는 model size의 2배의 memory가 저장되어 있다; 하나는 model weight의 복사본, 나머지 하나는 gradient의 복사본이다.
현재 memory : model + gradients
4. optimizer step
parameter를 업데이트, running optimizer parameter의 trak을 유지한다. 많은 optimizer들은 gradient의 first, second momentum을 측정함으로 parameter들의 track을 유지한다. Adam의 경우 2 momentum을 사용하므로 model size의 2배를 소모하고, RMSProp의 경우 1 momentum을 사용하므로 model size만큼의 memory 양을 소모할 것이다.
현재 memory : model + gradients + gradient moments
5. run the next iterations
처음 iteration이 끝나면 gradient와 moments가 남아있다. 따라서 2번째 step 이후부터는 maximum memory usage가 model + activations + gradients + gradient moments가 된다.
model weight, gradient들을 full 32-bit precision으로 저장하지만, 사용시에는 half precision으로 forward pass를 수행한다. 이는 forward pass 시 사용되는 memory 양을 절반으로 줄일 수 있다.
forward pass 시 사용되는 memory 양을 아낄 수 있는 또 다른 방법으로 학습시간을 조금 더 쓰는 대신 메모리를 아낄 수 있다. forward pass에서 산출되는 activation들을 모두 저장하는 것이 아니라 특정 값들만 저장 후 중간 gradient 연산을 다시 수행하면서 저장 하지 않은 값들을 복원하는 방법이다.
DDP 사용시, GPU들에 model을 load할 때, model size의 2배 만큼을 사용하는 것을 볼 수 있는데, 이는 "bucket"을 각 GPU에 만들어 다른 GPU로부터 gradient들을 모으는데 사용하기 때문이다. 따라서 DDP 사용시에는 gradient의 copy가 하나 더 사용된다.
inference시에는, 학습때와는 다르게 activation, gradient, ...를 저장, 연산할 필요가 없다. 따라서 torch.no_grad()를 사용해 inference를 수행하도록 하자.