[pytorch] Pytorch-lightning으로 mnist 구현 코드 분석(4) - 모델 학습(CNN-Convolutional Neural Network)

hye0n.gyu·2024년 4월 27일

DeepLearning

목록 보기
5/6
post-thumbnail

✔ Training Model 정의

def model_train(dataloader, model, loss_function, optimizer, device=None):
    """Run one training epoch over ``dataloader`` and return its metrics.

    Args:
        dataloader: iterable of ``(images, labels)`` batches that also supports
            ``len()`` (the number of batches).
        model: the network to train; it is switched to training mode here.
        loss_function: callable ``(outputs, labels) -> scalar loss tensor``.
        optimizer: optimizer that updates ``model``'s parameters.
        device: where each batch is moved. Defaults to the module-level
            ``DEVICE`` (the original behaviour).

    Returns:
        ``(average loss per batch, accuracy in percent)`` for the epoch.
    """
    if device is None:
        device = DEVICE  # keep the original global-device behaviour

    model.train()  # put the network in training mode

    train_loss_sum = train_correct = train_total = 0
    total_train_batch = len(dataloader)

    # images: input batch (for MNIST presumably (batch_size, 1, 28, 28));
    # labels: ground-truth class indices (0~9 for MNIST).
    for images, labels in dataloader:
        x_train = images.to(device)
        y_train = labels.to(device)

        outputs = model(x_train)  # forward pass: predictions for the batch
        loss = loss_function(outputs, y_train)  # loss between predictions and targets

        # Back-propagation: clear stale gradients, compute new ones, then let
        # the optimizer update the parameters (weights, biases).
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item()

        train_total += y_train.size(0)
        # Predicted class = index of the largest output along dim 1.
        train_correct += (torch.argmax(outputs, 1) == y_train).sum().item()

    train_avg_loss = train_loss_sum / total_train_batch  # mean loss per batch
    train_avg_accuracy = 100 * train_correct / train_total  # accuracy in percent

    return (train_avg_loss, train_avg_accuracy)

✔ evaluate Model 정의

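evaluate Model에는 역전파 코드가 없다. evaluate Model의 목적은 모델의 현재 상태를 확인하는 것, 즉 현재 오차와 정확도를 계산하는 것뿐이므로 파라미터를 업데이트하는 역전파(optimizer.zero_grad() ~ optimizer.step())가 필요 없기 때문이다. 참고로 torch.no_grad()는 기울기(gradient) 계산을 끄는 역할이고, 파라미터가 갱신되지 않는 직접적인 이유는 optimizer.step()을 호출하지 않기 때문이다.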

def model_evaluate(dataloader, model, loss_function, optimizer=None, device=None):
    """Measure loss and accuracy on ``dataloader`` without training.

    No back-propagation happens here: the model is only evaluated.

    Args:
        dataloader: iterable of ``(images, labels)`` batches that also supports
            ``len()`` (the number of batches).
        model: the network to evaluate; it is switched to eval mode here.
        loss_function: callable ``(outputs, labels) -> scalar loss tensor``.
        optimizer: unused; accepted (now optional) only so existing calls that
            pass it keep working.
        device: where each batch is moved. Defaults to the module-level
            ``DEVICE`` (the original behaviour).

    Returns:
        ``(average loss per batch, accuracy in percent)``.
    """
    if device is None:
        device = DEVICE  # keep the original global-device behaviour

    model.eval()  # put the network in inference (evaluation) mode

    # no_grad turns off gradient tracking (no autograd graph is built).
    # Parameters stay unchanged because optimizer.step() is never called.
    with torch.no_grad():
        val_loss_sum = val_correct = val_total = 0
        total_val_batch = len(dataloader)

        for images, labels in dataloader:
            x_val = images.to(device)
            y_val = labels.to(device)

            outputs = model(x_val)
            loss = loss_function(outputs, y_val)

            val_loss_sum += loss.item()

            val_total += y_val.size(0)
            # Predicted class = index of the largest output along dim 1.
            val_correct += (torch.argmax(outputs, 1) == y_val).sum().item()

        val_avg_loss = val_loss_sum / total_val_batch  # mean loss per batch
        val_avg_accuracy = 100 * val_correct / val_total  # accuracy in percent

    return (val_avg_loss, val_avg_accuracy)

✔ test Model 정의

def model_test(dataloader, model, loss_function=None, device=None):
    """Evaluate ``model`` on a test ``dataloader``, print and return metrics.

    Args:
        dataloader: iterable of ``(images, labels)`` batches that also supports
            ``len()`` (the number of batches).
        model: the network to test; it is switched to eval mode here.
        loss_function: callable ``(outputs, labels) -> scalar loss tensor``.
            If omitted, the module-level ``loss_function`` is used, which is
            what the original version silently relied on.
        device: where each batch is moved. Defaults to the module-level
            ``DEVICE`` (the original behaviour).

    Returns:
        ``(average loss per batch, accuracy in percent)``. The original
        returned ``None``; callers that ignore the result are unaffected.

    Raises:
        NameError: if no ``loss_function`` is passed and none exists at
            module level.
    """
    if loss_function is None:
        # The parameter shadows the global name, so look it up explicitly.
        loss_function = globals().get("loss_function")
        if loss_function is None:
            raise NameError(
                "model_test() needs a loss_function argument "
                "(no module-level loss_function is defined)"
            )
    if device is None:
        device = DEVICE  # keep the original global-device behaviour

    model.eval()  # put the network in inference (evaluation) mode

    with torch.no_grad():  # no gradient tracking needed for testing
        test_loss_sum = test_correct = test_total = 0
        total_test_batch = len(dataloader)

        for images, labels in dataloader:
            x_test = images.to(device)
            y_test = labels.to(device)

            outputs = model(x_test)
            loss = loss_function(outputs, y_test)

            test_loss_sum += loss.item()

            test_total += y_test.size(0)
            # Predicted class = index of the largest output along dim 1.
            test_correct += (torch.argmax(outputs, 1) == y_test).sum().item()

        test_avg_loss = test_loss_sum / total_test_batch  # mean loss per batch
        test_avg_accuracy = 100 * test_correct / test_total  # accuracy in percent

        print('accuracy:', test_avg_accuracy)
        print('loss:', test_avg_loss)

    return (test_avg_loss, test_avg_accuracy)

✔ model training


from datetime import datetime

# Per-epoch history of the metrics, appended to once per epoch below.
train_loss_list, train_accuracy_list = [], []
val_loss_list, val_accuracy_list = [], []

EPOCHS = 20

start_time = datetime.now()

for epoch in range(EPOCHS):
    # ---- training pass: updates the model's parameters ----
    train_avg_loss, train_avg_accuracy = model_train(
        train_dataset_loader, model, loss_function, optimizer
    )
    train_loss_list.append(train_avg_loss)
    train_accuracy_list.append(train_avg_accuracy)

    # ---- validation pass: measures loss/accuracy only ----
    val_avg_loss, val_avg_accuracy = model_evaluate(
        validation_dataset_loader, model, loss_function, optimizer
    )
    val_loss_list.append(val_avg_loss)
    val_accuracy_list.append(val_avg_accuracy)

    # One summary line per epoch (epoch number is 1-based, zero-padded).
    print('epoch:', '%02d' % (epoch + 1),
          'train loss =', '{:.3f}'.format(train_avg_loss),
          'train acc =', '{:.3f}'.format(train_avg_accuracy),
          'val loss =', '{:.3f}'.format(val_avg_loss),
          'val acc =', '{:.3f}'.format(val_avg_accuracy))

end_time = datetime.now()

# Wall-clock duration of the whole training run.
print('elapsed time => ', end_time - start_time)

출처: https://github.com/neowizard2018/neowizard/blob/master/PyTorch/PyTorch_LEC16_CNN_FashionMNIST_Example_CPCP_9262.ipynb

profile
반려묘 하루 velog

0개의 댓글