early stopping

d4r6j·2023년 8월 23일

ml modeling

목록 보기
2/5
post-thumbnail

REF : https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

Training 할 때, validation acc/loss 등 target 지표를 비교하고 더 이상 학습이 필요 없겠다고 판단 하면, epoch (학습) 을 끝내는 방법론.

import numpy as np
import torch

class EarlyStopping:

    # ... 중략 ...
   
    def __call__(self, val_loss, model):
        score = val_loss

        if self.best_score is None:
            self.best_score = score
            self.SaveCheckpoint(val_loss, model)

        elif score > self.best_score - self.delta:
            self.counter += 1
            self.trace_func(f'Validation loss increased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Not saving model ...')
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.counter = 0
            self.best_score = score
            self.SaveCheckpoint(val_loss, model)
  1. 설정 이후 처음 validation loss 는 best_score 로 저장한다.
  2. 다음 validation loss 가 저장된 best_score 를 비교
    • 작으면 best_score 를 교체
    • 아니면 self.counter 를 늘려서 self.patience 가 될 때까지 loop.

def SaveCheckpoint(self, val_loss, model):
    if self.verbose:
        self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
    
    torch.save(model.state_dict(), self.path)
    self.val_loss_min = val_loss

끝나면 SaveCheckpoint 에 마지막으로 저장된 모델로 끝나게 된다.


def Training(self
    , device: str
    , model
    , epoch: int
    , lr: float
    , train_data
    , val_data
    , early_stop: bool = False
    , save_path: str = "savemodel.pt"):
    
    if early_stop is True:
    earlyStopping = EarlyStopping(patience=3, verbose=True, path=save_path)
    
    optimizer = optim.Adam(model.parameters(), lr = lr)
    criterion = torch.nn.CrossEntropyLoss()
        
    for epoch in range(epoch):

Training 시에 early stopping 셋팅을 하게 되 각 epoch 에서 validation 을 할 때,


# validation 
val_progress = tqdm.tqdm(iterable=val_data
    , bar_format='{l_bar}{bar:25}{r_bar}'
    , ascii=True
    , total=len(val_data)
    , leave=True)

with torch.no_grad():
    model.eval()

    # ... 중략 ...

    if early_stop is True:
        earlyStopping(val_loss, model)
        
        if earlyStopping.early_stop is True:
            print("Early stopping")
            break

early stopping 에서 early_stop 이 True 가 되면 위와 같이 epoch 의 loop 를 break 로 빠져나오게 된다.

0개의 댓글