[DL] 모델 체크포인트

Sunguk·2023년 1월 12일
0

1. 모델 체크포인트

모델 체크포인트(Model Checkpoint)는 학습된 모델을 저장하는 기능입니다.

이를 사용하면 학습 중에 모델의 상태를 저장하고, 학습이 끝난 후에도 그 상태를 불러올 수 있습니다.

일반적으로 훈련 과정에서 일정 간격으로 저장하며, 이를 사용하면 학습 중에 일어난 에러나 실패를 복구할 수 있습니다.

2. 케라스에서의 사용방법

from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint('checkpoints/model.h5',
															save_best_only=True,
															save_weights_only=True,
															monitor='val_loss',
															mode='min',
															verbose=1)

hist = model.fit(
	                x_train,
	                y_train,
	                epochs=100,
	                batch_size=3,
	                validation_split=0.2,
	                verbose=1,
	                callbacks = [mcp]
                )

함수 선언 이후 model.fit 의 callback 함수의 인자 값으로 사용됩니다.

3. 인자값 설명

  • filepath: 저장할 파일 경로입니다. 파일명에 포함될 수 있는 파라미터는 epoch, acc, val_acc 등 입니다.
  • monitor: 모니터링 할 파라미터(metric)를 지정합니다.
  • save_best_only: 가장 좋은 모델만 저장할지 여부를 지정합니다.
  • mode: 'auto', 'min', 'max' 중 하나를 지정하면 그에 따라 monitor가 최소값/최대값일때만 저장합니다.

4. 시간별로 모델 저장하기

만약 최적의 모델 1개 가 아닌 에폭이 진행될때마다 모델을 기록하고싶다면?

import datetime
date = datetime.datetime.now()

# string for time 문자열로 바꿔
date =date.strftime("%y년 %m월 %d일 %H시 %m분")
print(type(date))
print(date)

filepath = "../_save/MCP/"

# 에폭과 val_loss 소수4째자리까지
filename = '{epoch:04d}-{val_loss:.4f}.hdf5'

mcp = ModelCheckpoint(
    monitor="val_loss",
    mode="auto",
    save_best_only=True,
    filepath= (filepath+date+filename),
    verbose=1 
)

datetime 으로 시간을 불러와서 파일명 이 각자 다르게 설정됩니다.

결론적으로 파일명 이 초기화되지 않아 시간때 별로 파일이 기록됩니다.

profile
안녕하세요

0개의 댓글