[PyTorch] 자주 쓰이는 train 관련 구문

조성운·2023년 3월 30일
0

PyTorch

목록 보기
1/2
post-thumbnail

훈련 시 반복문을 통해 dataloader에 접근할 때 pytorch에서는 다음과 같은 부분이 자주 쓴인다.

img, label = img.float().to(device), label.long().to(device)

# 1. 그래디언트 값 초기화
optimizer_ft.zero_grad()

# 2, forward : model에 데이터 입력
pred_logit = model_finetune(img)

# 3. loss 값 계산
loss = criterion(pred_logit, label)

# 4. Backpropagation
loss.backward() # loss 값 계산이 되고
optimizer_ft.step() # 계산된 loss값을 업데이트 한다.

# Accuracy 계산
pred_label = torch.argmax(pred_logit, 1)
acc = (pred_label == label).sum().item() / len(img)

train_loss = loss.item()
train_acc = acc

pred_logit = model_finetune(img)

  • model에 데이터를 입력하고 출력값을 pred_logit 변수에 담는다.
  • 실제 pred_logit의 shape은 [bs, num_classes] 이다.
  • bs(batch size) 하나 마다 num_classes 개수만큼의 0과 1사이의 값이 담겨 있다.

pred_label = torch.argmax(pred_logit, 1)

  • bs 마다 num_classes의 확률값 중에서 가장 max한 값의 index를 저장한다.
  • 실제 pred_label의 shape은 [bs] 이다.
  • 각 bs 별로 num_classes 개수만큼의 값들 중에서 가장 큰 값, 즉 가장 큰 확률값을 가지는 인덱스에 해당하는 클래스가 해당 데이터를 보고 모델이 예측한 값이 된다.

acc = (pred_label == label).sum().item() / len(img)

  • pred_label == label 구문을 통해 인덱스 값이 동일한 경우만 True 값으로 나오도록 한다.
  • 이후 .sum()을 통해 True 개수 즉 실제 모델이 맞춘 개수의 합을 구한다.
  • 이후 /len(img) 이미지 개수만큼 나눠 평균 Accuray를 계산한다.

valid_loss, valid_acc = AverageMeter(), AverageMeter()

valid_loss.update(loss.item(), len(img))
valid_acc.update(acc, len(img))

profile
일단 적을게요

0개의 댓글