AMP는 Automatic Mixed Precision package의 약자로 모델의 single precision(FP32)를 두 종류의 precision(FP16, FP32)으로 학습하게 하여 빠르게 학습을 하게해주는 패키지이다.
모델의 Foward 연산은 서로 다른 두 행렬을 행렬 곱을 하는 것과 같다. 이 말은 각각의 두 행렬의 precision이 일치해야 동일한 값으로 간주하여 연산을 할 수 있다는 말이다. Input이 FP16이면 Operation을 FP16으로 진행하고(FP16으로 cast 가능한 연산에 한하여), Input이 FP32이면 Operation을 cast하지 않고 FP32으로 그대로 진행한다(또는 FP32로 연산할 수 없는 연산이면 그대로 FP32를 사용). Input이 FP16과 FP32로 혼용 되어 있는 상태이면, operation은 FP32로 연산하고 FP32 output 을 만든다.
⇒ FP16으로 연산할 수 있는 operation과 상황이 제한적으로 있으므로 mixed precision 이다.
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass (model + loss)
with autocast():
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward()
loss.backward()
optimizer.step()
⇒ 위와 같이 Backward 연산이 Forward시에 사용한 precision을 사용하면 underflow 가 발생할 수 있는 문제 점이 있음.
Forward pass를 FP16으로 하였다면, Backward pass 또한 FP16 gradient를 생성한다. Gradient의 값은 약간 크기 때문에 FP16으로 표현할 수 없을 수 있다. 이러한 현상을 underflow(0 이하의 값으로 넘친다)라고 표현한다. 이 현상을 해결하기 위해서 loss에 일정 크기의 scale factor를 곱해줌으로 써 backward pass시에 scaled loss로 연산을 하게하는 gradient scaling 을 한다.
Optimzier update를 하기 이전에 unscaled를 해주기 때문에, scale factor가 방해하는 요소로 작용하지 않는다.
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
scaler = torch.cuda.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target)
# (retain_graph here is unrelated to amp, it's present because in this
# example, both backward() calls share some sections of graph.)
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
# You can choose which optimizers receive explicit unscaling, if you
# want to inspect or modify the gradients of the params they own.
scaler.unscale_(optimizer0)
scaler.step(optimizer0)
scaler.step(optimizer1)
scaler.update()