저번엔 데이터 셋을 불러오고 간단한 CNN 모델을 만드는 것까지 해보았다. 지금부터는 로드해온 데이터셋을 모델에 넣어 어떻게 학습하는지에 대해서 알아보려고 한다.
코딩하는 것이 중요하다기 보단 왜 이렇게 코딩을 했는지를 우선시해서 보면 좋을 것 같다.
## train.py
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import matplotlib.pyplot
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Train 함수의 매개변수는 사용자마다 다르다
def Train(model, train_DL, criterion, optimizer, EPOCH):
loss_history = [] # loss를 담을 그릇을 준비한다
NoT = len(train_DL.dataset)
model.train() # train mode로 변경해야 한다.
for ep in range(EPOCH):
rloss = 0
for x_batch, y_batch in train_DL:
x_batch = x_batch.to(DEVICE)
y_batch = y_batch.to(DEVICE)
y_hat = model(x_batch)
loss = criterion(y_hat, y_batch) # (모델 예측값, 실제값) 순서 매우 중요
optimizer.zero_grad() # 기울기를 누적하므로 0으로 초기화해줘야 한다. 안그럼 계속 누적됨
loss.backward() # BackPropagation 진행
optimizer.step() # 최신화
loss_b = loss.item() * x_batch.shape[0]
# criterion에서 나온 값은 이미 x_batch의 갯수만큼 나뉘어서 loss.item()에 들어감.
# 그 이유는 nn.CrossEntropyLoss는 이미 softmax함수가 적용되어 있기 때문이다.
# 아래의 링크를 확인하라
rloss += loss_b
loss_e = rloss/NoT
loss_history += [loss_e]
print(f"EPOCH : {ep + 1}, train_loss : {round(loss_e, 3}"
print("-" * 20)
return loss_history
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
train 함수는 이게 끝이다. 절반만 말이다. 일반적으로 train을 할때 validataion도 같이 진행한다. 학습되지 않은 데이터들을 보고 얼마나 이 모델이 잘 학습되고 있는지를 판단하기 위해선 train loss보단 validation loss가 잘 떨어지는가가 가장 중요하다. validation loss를 구하는 건 그리 어렵지 않다. 위의 코드를 복붙하면 되기 때문이다. 지금은 valiation loss를 구현하지 않고 그냥 넘어가고 후에 전체 코드를 올리려 한다.
잘 학습되고 있는지 test를 하기 위한 test 함수도 만들어보자.
def Test(model, test_DL):
model.eval() # 모델을 평가 모드로 바꿔줘야 한다.
with torch.no_grad(): # 학습단계가 아니니 기울기를 구할 이유가 없다.
r_correct = 0
for x_batch, y_batch in test_DL:
x_batch = x_batch.to(DEVICE)
y_batch = y_batch.to(DEVICE)
y_hat = model(x_batch)
pred = y_hat.argmax(dim=1) # 가장 큰 원소의 인덱스를 반환한다.
corrects_b = torch.sum(pred == y_batch).item()
r_correct += corrects_b
accuracy_e = r_correct / len(test_DL.dataset) * 100
print(f"Test accuracy: {r_correct}/{len(test_DL.dataset)} ({round(accuracy_e)}%)"
이게 test 함수이다.
코딩 자체의 어려움은 없다. 하지만 이제부터 난이도가 올라가는 것들이 많다.
무슨 말이냐 하면 분류는 그나마 간단한 task라는 것이다. Object Detection task는 여러 과정을 하나의 loss로 평가해야하고 이를 빠르게 추론하기 위한 task이기 때문에 이런 분류 문제보단 매우 어려운 task 중의 하나이다. 앞으로 나와 같이 이 Object Detction도 같이 여행해볼 예정이다. 같이 기대하자