The basic approach is the torch.save() + torch.load() combination; if you only want to save the weights, you additionally need to extract and load them through state_dict().

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchsummary as ts
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 1-channel 28x28 input -> 32 feature maps (padding keeps the spatial size)
        self.conv1 = nn.Conv2d(1, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.dropout = nn.Dropout(0.25)
        # two 2x2 max-pools shrink 28x28 down to 7x7, hence 64 * 7 * 7 inputs
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)   # 28x28 -> 14x14
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)   # 14x14 -> 7x7
        x = self.dropout(x)
        x = x.view(x.shape[0], -1)  # flatten to (batch, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)  # log-probabilities for nll_loss
        return x
cnn = CNN().to(device)
print(cnn)
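The torchsummary import above is not used anywhere else in the script; below is a minimal sketch of how it could be applied to this model. Note that torchsummary's summary() only accepts "cuda" or "cpu" as its device argument, so the sketch assumes you run it on a CPU copy of the model:

# Sketch: per-layer output shapes and parameter counts (assumes torchsummary is installed)
ts.summary(CNN().to("cpu"), input_size=(1, 28, 28), device="cpu")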
train_data = datasets.MNIST(
    "./data",
    train=True,
    download=True,  # download MNIST into ./data if it is not already there
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    ),
)
test_data = datasets.MNIST(
    "./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    ),
)
train_loader = DataLoader(
    train_data,
    shuffle=True,
    batch_size=256,
)
test_loader = DataLoader(
    test_data,
    shuffle=False,
    batch_size=256,
)
optimizer = optim.Adam(cnn.parameters(), lr=0.03)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    patience=2,  # if the loss has not improved for 2 epochs,
    factor=0.1,  # reduce the learning rate to 0.1x its current value
)
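If you want to confirm that ReduceLROnPlateau actually lowers the learning rate once the loss stalls, the current value can be read back from the optimizer. A small sketch (not part of the original script) that could be printed inside the epoch loop below:

# Sketch: inspect the learning rate the scheduler is currently applying
current_lr = optimizer.param_groups[0]["lr"]
print(f"current learning rate: {current_lr}")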
for epoch in range(10):
    cnn.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = cnn(images)
        loss = F.nll_loss(outputs, labels)
        loss.backward()
        optimizer.step()  # update the weights
    # loss here is the loss of the last training batch in this epoch
    print(f"[Epoch {epoch+1}] Train Loss: {loss.item():.5f}")
    scheduler.step(loss)  # let the scheduler adjust the learning rate
    cnn.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = cnn(images)
            loss = F.nll_loss(outputs, labels)
    # loss here is the loss of the last test batch
    print(f"[Epoch {epoch+1}] Test Loss: {loss.item():.5f}")
# [Epoch 1] Train Loss: 0.47822
# [Epoch 1] Test Loss: 0.06771
# [Epoch 2] Train Loss: 0.24769
# [Epoch 2] Test Loss: 0.07236
# [Epoch 3] Train Loss: 0.24033
# [Epoch 3] Test Loss: 0.04768
# [Epoch 4] Train Loss: 0.13009
# [Epoch 4] Test Loss: 0.00878
# [Epoch 5] Train Loss: 0.14712
# [Epoch 5] Test Loss: 0.00319
# [Epoch 6] Train Loss: 0.18090
# [Epoch 6] Test Loss: 0.01081
# [Epoch 7] Train Loss: 0.19649
# [Epoch 7] Test Loss: 0.00811
# [Epoch 8] Train Loss: 0.06929
# [Epoch 8] Test Loss: 0.00624
# [Epoch 9] Train Loss: 0.07021
# [Epoch 9] Test Loss: 0.00735
# [Epoch 10] Train Loss: 0.10174
# [Epoch 10] Test Loss: 0.00636
# Save the weights only
os.makedirs("./model", exist_ok=True)  # make sure the target directory exists
torch.save(cnn.state_dict(), "./model/cnn_weights.pt")
# Create a fresh model instance
cnn_only_weights = CNN().to(device)
# Load the saved weights into it
cnn_only_weights.load_state_dict(torch.load("./model/cnn_weights.pt", weights_only=True))
# Save the entire model (architecture + weights)
torch.save(cnn, "./model/cnn_model.pt")
# Load the entire model
cnn_model = torch.load("./model/cnn_model.pt", weights_only=False)
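As a quick sanity check (a sketch, not from the original material), both restored models should produce exactly the same predictions as the trained cnn once they are switched to eval mode:

# Sketch: verify that the restored models reproduce the trained model's predictions
cnn.eval()
cnn_only_weights.eval()
cnn_model.eval()
with torch.no_grad():
    images, _ = next(iter(test_loader))
    images = images.to(device)
    preds_original = cnn(images).argmax(dim=1)
    preds_weights = cnn_only_weights(images).argmax(dim=1)
    preds_full = cnn_model(images).argmax(dim=1)
print(torch.equal(preds_original, preds_weights))  # expected: True
print(torch.equal(preds_original, preds_full))     # expected: True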
*This post is based on excerpts from the lecture materials of the 제로베이스 데이터 취업 스쿨 (Zerobase Data School).