Early stopping

김혁·2022년 10월 10일
0

Early stopping 이란?

딥러닝을 한 가지 중요한 딜레마로는 너무 많은 Epoch 은 overfitting 을 일으킵니다. 하지만 너무 적은 Epoch 은 underfitting 을 일으킵니다. 이런 상황에서 Epoch을 어떻게 설정해야 좋을까요? Epoch 을 정하는데 많이 사용되는 Early stopping 은 무조건 Epoch 을 많이 돌린 후, 특정 시점에서 멈추는 것입니다. 그 특정시점을 어떻게 정하느냐가 Early stopping 의 핵심이라고 할 수 있습니다. 데이터셋을 불러와서 코드를 작성해보고 Early stopping을 알아보도록 하겠습니다.

  1. 데이터셋을 불러오는 코드는 생략하고 간단한 MLP를 선언하겠습니다.
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128,10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # image input을 펼쳐준다.
        x = x.view(-1, 28*28)

        # 은닉층을 추가하고 활성화 함수로 relu 사용
        x = F.relu(self.fc1(x))
        x = self.dropout(x)

        # 은닉층을 추가하고 활성화 함수로 relu 사용
        x = F.relu(self.fc2(x))
        x = self.dropout(x)

        # 출력층 추가
        x = self.fc3(x)
        return x

# initialize the NN / 모델 확인
model = Net()
print(model)
Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
  1. loss function과 optimizer를 선언하겠습니다.
# loss function의 정의(CrossEntropyLoss)
criterion = nn.CrossEntropyLoss()

# optimizer 정의(Adam)
optimizer = torch.optim.Adam(model.parameters())
  1. EarlyStopping 클래스를 early stopping을 참고했던 원저자가 작성한 pytorchtools 활용하여 불러와야하는데, colab환경에서 직접 사용하느라 여기에서 EarlyStopping를 정의하였습니다.
class EarlyStopping:
    """주어진 patience 이후로 validation loss가 개선되지 않으면 학습을 조기 중지"""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): validation loss가 개선된 후 기다리는 기간
                            Default: 7
            verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
                            Default: False
            delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
                            Default: 0
            path (str): checkpoint저장 경로
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''validation loss가 감소하면 모델을 저장한다.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
  1. Early stopping을 사용하여 model을 train 하겠습니다
import numpy as np 

def train_model(model, batch_size, patience, n_epochs):

    # 모델이 학습되는 동안 trainning loss를 track
    train_losses = []
    # 모델이 학습되는 동안 validation loss를 track
    valid_losses = []
    # epoch당 average training loss를 track
    avg_train_losses = []
    # epoch당 average validation loss를 track
    avg_valid_losses = []

    # early_stopping object의 초기화
    early_stopping = EarlyStopping(patience = patience, verbose = True)

    for epoch in range(1, n_epochs + 1):

        ###################
        # train the model #
        ###################
        model.train() # prep model for training
        for batch, (data, target) in enumerate(train_loader, 1):
            # clear the gradients of all optimized variables
            optimizer.zero_grad()    
            # forward pass: 입력된 값을 모델로 전달하여 예측 출력 계산
            output = model(data)
            # calculate the loss
            loss = criterion(output, target)
            # backward pass: 모델의 파라미터와 관련된 loss의 그래디언트 계산
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # record training loss
            train_losses.append(loss.item())


        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation
        for data , target in valid_loader :
            # forward pass: 입력된 값을 모델로 전달하여 예측 출력 계산
            output = model(data)
            # calculate the loss
            loss = criterion(output, target)
            # record validation loss
            valid_losses.append(loss.item())

        # print 학습/검증 statistics
        # epoch당 평균 loss 계산
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(n_epochs))


        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')

        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        # early_stopping는 validation loss가 감소하였는지 확인이 필요하며,
        # 만약 감소하였을경우 현제 모델을 checkpoint로 만든다.
        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

   # best model이 저장되어있는 last checkpoint를 로드한다.
    model.load_state_dict(torch.load('checkpoint.pt'))

    return  model, avg_train_losses, avg_valid_losses
batch_size = 256
n_epochs = 100

train_loader, test_loader, valid_loader = create_datasets(batch_size)

# early stopping patience;
# validation loss가 개선된 마지막 시간 이후로 얼마나 기다릴지 지정
patience = 20

model, train_loss, valid_loss = train_model(model, batch_size, patience, n_epochs)   
[ 60/100] train_loss: 0.08084 valid_loss: 0.09212
EarlyStopping counter: 14 out of 20
[ 61/100] train_loss: 0.08203 valid_loss: 0.09326
EarlyStopping counter: 15 out of 20
[ 62/100] train_loss: 0.08309 valid_loss: 0.09803
EarlyStopping counter: 16 out of 20
[ 63/100] train_loss: 0.07713 valid_loss: 0.09040
EarlyStopping counter: 17 out of 20
[ 64/100] train_loss: 0.07699 valid_loss: 0.09509
EarlyStopping counter: 18 out of 20
[ 65/100] train_loss: 0.07717 valid_loss: 0.09056
EarlyStopping counter: 19 out of 20
[ 66/100] train_loss: 0.07682 valid_loss: 0.09458
EarlyStopping counter: 20 out of 20
Early stopping

validation loss가 줄어들면 저장을 하고 저장된 validation loss보다 줄지 않으면 Early stopping counter가 증가하고 20이 되었을 때 학습이 중지 된 것을 확인할 수 있습니다.

  1. 시각화를 이용해서 plot에서 모델이 과적합되기 직전에 마지막 Early Stopping Checkpoint가 저장됨을 알 수 있습니다.
import matplotlib.pyplot as plt

# 훈련이 진행되는 과정에 따라 loss를 시각화
fig = plt.figure(figsize=(10,8))
plt.plot(range(1,len(train_loss)+1),train_loss, label='Training Loss')
plt.plot(range(1,len(valid_loss)+1),valid_loss,label='Validation Loss')

# validation loss의 최저값 지점을 찾기
minposs = valid_loss.index(min(valid_loss))+1
plt.axvline(minposs, linestyle='--', color='r',label='Early Stopping Checkpoint')

plt.xlabel('epochs')
plt.ylabel('loss')
plt.ylim(0, 0.5) # 일정한 scale
plt.xlim(0, len(train_loss)+1) # 일정한 scale
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
fig.savefig('loss_plot.png', bbox_inches = 'tight')

참고 : https://github.com/Bjarten/early-stopping-pytorch

profile
군도리

0개의 댓글