앞에서 만든 신경망을 파이토치로 구현했다.
from torchvision.datasets import FashionMNIST
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch
import torch.optim as optim
# Fetch the Fashion-MNIST training and test splits into the current directory
# (downloaded on first run).
fm_train = FashionMNIST(root='.', train=True, download=True)
fm_test = FashionMNIST(root='.', train=False, download=True)

# Raw training images and their class labels.
train_input, train_target = fm_train.data, fm_train.targets

# Normalize by dividing by 255 (assumes 8-bit 0-255 pixel values, as the
# constant suggests — this would map them into [0, 1]).
train_scaled = train_input / 255.0

# Hold out 20% of the training data as a validation set; the fixed seed makes
# the split reproducible between runs.
train_scaled, val_scaled, train_target, val_target = train_test_split(
    train_scaled,
    train_target,
    test_size=0.2,
    random_state=42,
)
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fully-connected classifier: flatten each input, map 784 features to 100
# ReLU hidden units (with 30% dropout), then produce 10 output scores.
# Module.to() moves the parameters in place and returns the module itself.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 100),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(100, 10),
).to(device)

# Cross-entropy loss; the last layer is a plain Linear, so it receives raw logits.
criterion = nn.CrossEntropyLoss()

# Adam optimizer over all model parameters (no learning rate given, so the default is used).
optimizer = optim.Adam(model.parameters())
# Per-epoch loss history: mean training loss and validation loss.
train_hist = []
val_hist = []

# Early stopping: stop after `patience` consecutive epochs without a new best
# validation loss.
patience = 2
best_loss = float('inf')  # any real loss beats this, so epoch 1 always counts as an improvement
early_stopping_counter = 0

epochs = 20
batch_size = 32
# Ceiling division so a final partial batch is trained on instead of silently dropped.
batches = (len(train_scaled) + batch_size - 1) // batch_size

# The whole validation set is evaluated in one pass every epoch; move it to the
# device once instead of on every epoch.
val_scaled = val_scaled.to(device)
val_target = val_target.to(device)

for epoch in range(epochs):
    model.train()  # training mode (dropout active)
    train_loss = 0
    for i in range(batches):
        start = i * batch_size
        inputs = train_scaled[start:start + batch_size].to(device)
        targets = train_target[start:start + batch_size].to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(inputs)  # predictions for this mini-batch
        loss = criterion(outputs, targets)  # compare predictions with the true labels
        loss.backward()  # backpropagate
        optimizer.step()  # update the weights
        train_loss += loss.item()

    model.eval()  # evaluation mode (dropout disabled)
    with torch.no_grad():
        outputs = model(val_scaled)
        # A single full-set pass: this is already the validation loss for the
        # epoch, so it must NOT be divided by the number of training batches.
        val_loss = criterion(outputs, val_target).item()

    mean_train_loss = train_loss / batches
    train_hist.append(mean_train_loss)
    val_hist.append(val_loss)
    print(f"{epoch+1}/{epochs}, train_loss: {mean_train_loss:.4f}, val_loss: {val_loss:.4f}", end=' ')

    if val_loss < best_loss:
        # New best model: reset the counter and actually persist its weights.
        best_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
        print('모델 저장')
    else:
        early_stopping_counter += 1
        # Stop once the validation loss has not improved for `patience` epochs.
        if early_stopping_counter >= patience:
            print(f'{epoch+1}번째에서 조기 종료')
            break
        print()  # terminate the status line that was printed with end=' '