딥러닝을 한 가지 중요한 딜레마로는 너무 많은 Epoch 은 overfitting 을 일으킵니다. 하지만 너무 적은 Epoch 은 underfitting 을 일으킵니다. 이런 상황에서 Epoch을 어떻게 설정해야 좋을까요? Epoch 을 정하는데 많이 사용되는 Early stopping 은 무조건 Epoch 을 많이 돌린 후, 특정 시점에서 멈추는 것입니다. 그 특정시점을 어떻게 정하느냐가 Early stopping 의 핵심이라고 할 수 있습니다. 데이터셋을 불러와서 코드를 작성해보고 Early stopping을 알아보도록 하겠습니다.
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28*28, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128,10)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
# image input을 펼쳐준다.
x = x.view(-1, 28*28)
# 은닉층을 추가하고 활성화 함수로 relu 사용
x = F.relu(self.fc1(x))
x = self.dropout(x)
# 은닉층을 추가하고 활성화 함수로 relu 사용
x = F.relu(self.fc2(x))
x = self.dropout(x)
# 출력층 추가
x = self.fc3(x)
return x
# initialize the NN / 모델 확인
model = Net()
print(model)
Net(
(fc1): Linear(in_features=784, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=128, bias=True)
(fc3): Linear(in_features=128, out_features=10, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
)
# loss function의 정의(CrossEntropyLoss)
criterion = nn.CrossEntropyLoss()
# optimizer 정의(Adam)
optimizer = torch.optim.Adam(model.parameters())
class EarlyStopping:
"""주어진 patience 이후로 validation loss가 개선되지 않으면 학습을 조기 중지"""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
"""
Args:
patience (int): validation loss가 개선된 후 기다리는 기간
Default: 7
verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
Default: False
delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
Default: 0
path (str): checkpoint저장 경로
Default: 'checkpoint.pt'
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
self.path = path
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
'''validation loss가 감소하면 모델을 저장한다.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
import numpy as np
def train_model(model, batch_size, patience, n_epochs):
# 모델이 학습되는 동안 trainning loss를 track
train_losses = []
# 모델이 학습되는 동안 validation loss를 track
valid_losses = []
# epoch당 average training loss를 track
avg_train_losses = []
# epoch당 average validation loss를 track
avg_valid_losses = []
# early_stopping object의 초기화
early_stopping = EarlyStopping(patience = patience, verbose = True)
for epoch in range(1, n_epochs + 1):
###################
# train the model #
###################
model.train() # prep model for training
for batch, (data, target) in enumerate(train_loader, 1):
# clear the gradients of all optimized variables
optimizer.zero_grad()
# forward pass: 입력된 값을 모델로 전달하여 예측 출력 계산
output = model(data)
# calculate the loss
loss = criterion(output, target)
# backward pass: 모델의 파라미터와 관련된 loss의 그래디언트 계산
loss.backward()
# perform a single optimization step (parameter update)
optimizer.step()
# record training loss
train_losses.append(loss.item())
######################
# validate the model #
######################
model.eval() # prep model for evaluation
for data , target in valid_loader :
# forward pass: 입력된 값을 모델로 전달하여 예측 출력 계산
output = model(data)
# calculate the loss
loss = criterion(output, target)
# record validation loss
valid_losses.append(loss.item())
# print 학습/검증 statistics
# epoch당 평균 loss 계산
train_loss = np.average(train_losses)
valid_loss = np.average(valid_losses)
avg_train_losses.append(train_loss)
avg_valid_losses.append(valid_loss)
epoch_len = len(str(n_epochs))
print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
f'train_loss: {train_loss:.5f} ' +
f'valid_loss: {valid_loss:.5f}')
print(print_msg)
# clear lists to track next epoch
train_losses = []
valid_losses = []
# early_stopping는 validation loss가 감소하였는지 확인이 필요하며,
# 만약 감소하였을경우 현제 모델을 checkpoint로 만든다.
early_stopping(valid_loss, model)
if early_stopping.early_stop:
print("Early stopping")
break
# best model이 저장되어있는 last checkpoint를 로드한다.
model.load_state_dict(torch.load('checkpoint.pt'))
return model, avg_train_losses, avg_valid_losses
batch_size = 256
n_epochs = 100
train_loader, test_loader, valid_loader = create_datasets(batch_size)
# early stopping patience;
# validation loss가 개선된 마지막 시간 이후로 얼마나 기다릴지 지정
patience = 20
model, train_loss, valid_loss = train_model(model, batch_size, patience, n_epochs)
[ 60/100] train_loss: 0.08084 valid_loss: 0.09212
EarlyStopping counter: 14 out of 20
[ 61/100] train_loss: 0.08203 valid_loss: 0.09326
EarlyStopping counter: 15 out of 20
[ 62/100] train_loss: 0.08309 valid_loss: 0.09803
EarlyStopping counter: 16 out of 20
[ 63/100] train_loss: 0.07713 valid_loss: 0.09040
EarlyStopping counter: 17 out of 20
[ 64/100] train_loss: 0.07699 valid_loss: 0.09509
EarlyStopping counter: 18 out of 20
[ 65/100] train_loss: 0.07717 valid_loss: 0.09056
EarlyStopping counter: 19 out of 20
[ 66/100] train_loss: 0.07682 valid_loss: 0.09458
EarlyStopping counter: 20 out of 20
Early stopping
validation loss가 줄어들면 저장을 하고 저장된 validation loss보다 줄지 않으면 Early stopping counter가 증가하고 20이 되었을 때 학습이 중지 된 것을 확인할 수 있습니다.
import matplotlib.pyplot as plt
# 훈련이 진행되는 과정에 따라 loss를 시각화
fig = plt.figure(figsize=(10,8))
plt.plot(range(1,len(train_loss)+1),train_loss, label='Training Loss')
plt.plot(range(1,len(valid_loss)+1),valid_loss,label='Validation Loss')
# validation loss의 최저값 지점을 찾기
minposs = valid_loss.index(min(valid_loss))+1
plt.axvline(minposs, linestyle='--', color='r',label='Early Stopping Checkpoint')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.ylim(0, 0.5) # 일정한 scale
plt.xlim(0, len(train_loss)+1) # 일정한 scale
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
fig.savefig('loss_plot.png', bbox_inches = 'tight')