Today I Learned
written by 602
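aka. How to write clean code comments
Following the four principles above, I cleaned up the code of the sign language generation project (text2keypoint).
The project has several files; from them, I picked the MyCallback class in train.py.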
"""
Set callbacks and execute training process.
"""
import tensorflow as tf

# make_predict is a project helper defined elsewhere in text2keypoint.


class MyCallback(tf.keras.callbacks.Callback):
    """
    Custom callback class that prints the learning rate every 5 epochs
    and saves the model and predictions every 10 epochs.
    (The given periods are just for the example data (train, dev, test are 5 each). Freely change the values as you want.)
    :method __init__: store the required parameters
    :method on_epoch_end: run at the end of each epoch
    """
    def __init__(self, name, cfg, X, y, decoder_input_array,
                 output_file, output_gloss, output_skels):
        """
        Store the required parameters.
        Other values (e.g. epoch) are inherited from the default keras callback.
        :param name: accepted from the caller; not used inside this class
        :param cfg: configuration dictionary
        :param X: X_dev from data.py
        :param y: y_dev from data.py
        :param decoder_input_array: decoder_input_array_dev from data.py
        :param output_file: output_file_dev
        :param output_gloss: output_gloss_dev
        :param output_skels: output_skels_dev
        """
        super().__init__()  # inherit from tf.keras.callbacks.Callback
        # additional params
        self.cfg = cfg
        self.X = X
        self.y = y
        self.decoder_input_array = decoder_input_array
        self.output_file = output_file
        self.output_gloss = output_gloss
        self.output_skels = output_skels

    def on_epoch_end(self, epoch, logs=None):
        """
        Executed at the end of every epoch.
        (epoch is the zero-based index passed by keras.)
        """
        # Print learning rate every 5 epochs
        if epoch > 0 and epoch % 5 == 0:
            lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
            print('learning rate : ', lr)

        # Save model and prediction every 10 epochs
        if epoch > 0 and epoch % 10 == 0:
            model_path = self.cfg["model_path"]
            self.model.save(model_path + "model.h5")
            result_path = self.cfg["result_path"]
            make_predict(self.cfg, self.model, self.X, self.y, self.decoder_input_array,
                         self.output_file, self.output_gloss, self.output_skels,
                         result_path, epoch, best=False)
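
The two `if` checks in on_epoch_end decide when each action fires. Keras passes a zero-based epoch index, so it can help to test the condition in isolation. The following is only a rough sketch: the helper `fires_on` and the 21-epoch run are hypothetical and not part of text2keypoint.

```python
def fires_on(epoch: int, period: int) -> bool:
    """Mirror the on_epoch_end condition: skip index 0, then fire on every multiple of `period`."""
    return epoch > 0 and epoch % period == 0


# Zero-based epoch indices for a hypothetical 21-epoch run.
epochs = range(21)
lr_epochs = [e for e in epochs if fires_on(e, 5)]
save_epochs = [e for e in epochs if fires_on(e, 10)]

print(lr_epochs)    # [5, 10, 15, 20]
print(save_epochs)  # [10, 20]
```

Because the index is zero-based, index 5 is the sixth epoch of training, so "every 5 epochs" here means after the 6th, 11th, 16th epoch and so on.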
The class definitely looks more professional and cleaner than before!
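
For reference, the `:param ...:` lines follow the reST field-list style that documentation tools such as Sphinx can read. Below is a minimal sketch of the same layout on a small function. `build_model_file` and the `"models/"` prefix are illustrative names assumed for this example, not code from text2keypoint. It also shows that the docstring stays available at runtime through `inspect.getdoc`.

```python
import inspect


def build_model_file(model_path: str, filename: str = "model.h5") -> str:
    """
    Build the file path used to save a model.

    :param model_path: directory prefix, expected to end with a separator
    :param filename: file name appended to the prefix
    :returns: the concatenated path
    """
    return model_path + filename


# "models/" is a made-up prefix used only for this example.
print(build_model_file("models/"))

# inspect.getdoc returns the docstring with its indentation cleaned up.
print(inspect.getdoc(build_model_file))
```

Field names in this style match the function's argument names (`model_path`, not `self.model_path`), which lets tools pair each description with its parameter.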