
REF : https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
Early stopping is a technique that, during training, monitors a target metric such as validation accuracy/loss and ends the epoch loop once it judges that further training is no longer useful.
import numpy as np
import torch

class EarlyStopping:
    # ... __init__ omitted (see the referenced repository): it stores patience,
    # delta, verbose, path, and trace_func, and initializes counter = 0,
    # best_score = None, early_stop = False, and val_loss_min = np.inf.

    def __call__(self, val_loss, model):
        score = val_loss
        if self.best_score is None:
            # First call: record the score and save an initial checkpoint.
            self.best_score = score
            self.SaveCheckpoint(val_loss, model)
        elif score > self.best_score - self.delta:
            # The loss did not improve by at least delta: count toward patience.
            self.counter += 1
            self.trace_func(f'Validation loss did not improve ({self.val_loss_min:.6f} --> {val_loss:.6f}). Not saving model ...')
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # The loss improved: reset the counter and save a new checkpoint.
            self.counter = 0
            self.best_score = score
            self.SaveCheckpoint(val_loss, model)

    def SaveCheckpoint(self, val_loss, model):
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
When training finishes, the weights on disk are the ones last written by SaveCheckpoint, i.e., the best model observed so far.
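As a quick illustration, here is a minimal sketch that drives EarlyStopping with a synthetic loss sequence and then restores the best checkpoint. The stand-in model, the loss values, and the assumption that the omitted __init__ follows the referenced repository (delta defaulting to 0) are hypothetical, not part of the post's code.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in model
stopper = EarlyStopping(patience=3, verbose=True, path="savemodel.pt")

# Synthetic validation losses: improvement, then stagnation.
for val_loss in [0.90, 0.75, 0.74, 0.76, 0.75, 0.77]:
    stopper(val_loss, model)
    if stopper.early_stop:
        print("Early stopping")
        break

# Restore the weights from the last checkpoint SaveCheckpoint wrote (the best model).
model.load_state_dict(torch.load("savemodel.pt"))

With patience=3, the counter reaches 3 on the losses 0.76, 0.75, 0.77 (none beats the best value 0.74), so early_stop becomes True and the loop breaks; the checkpoint on disk still holds the weights from the 0.74 step.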
import tqdm
import torch
import torch.optim as optim

def Training(self  # method of a trainer class
             , device: str
             , model
             , epoch: int
             , lr: float
             , train_data
             , val_data
             , early_stop: bool = False
             , save_path: str = "savemodel.pt"):
    if early_stop:
        earlyStopping = EarlyStopping(patience=3, verbose=True, path=save_path)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    for ep in range(epoch):  # loop variable renamed so it does not shadow the epoch argument
        # ... training pass omitted ...
With early stopping enabled in Training, the validation step of each epoch looks like this:
        # validation
        val_progress = tqdm.tqdm(iterable=val_data
                                 , bar_format='{l_bar}{bar:25}{r_bar}'
                                 , ascii=True
                                 , total=len(val_data)
                                 , leave=True)
        with torch.no_grad():
            model.eval()
            # ... omitted: iterate over val_progress and accumulate val_loss ...
        if early_stop:
            earlyStopping(val_loss, model)
            if earlyStopping.early_stop:
                print("Early stopping")
                break
Once early_stop becomes True inside EarlyStopping, the epoch loop is exited with break, as shown above.
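For reference, the omitted validation pass could look like the following minimal sketch; the (inputs, labels) batch layout and the loss averaging are assumptions for illustration, not the post's actual code.

import torch

def validate(model, val_progress, criterion, device):
    # Hypothetical helper: average the criterion over the validation batches.
    total_loss, n_batches = 0.0, 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in val_progress:  # assumed (input, label) batches
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            n_batches += 1
    return total_loss / max(n_batches, 1)  # this is the val_loss passed to earlyStopping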