
def evaluate_plot(metric, train_metrics, train_losses, valid_metrics, valid_losses):
    """Plot per-epoch training/validation metric and loss curves on twin y-axes.

    author : d4r6j

    The metric curves are drawn against the left y-axis and the loss curves
    against a second y-axis (created with ``twinx``) that shares the same
    epoch x-axis; epochs are numbered from 1. A dashed red vertical line
    marks the first epoch with the lowest validation loss. The figure is
    saved to ``<YYYY-mm-dd.HH-MM-SS>_<metric>_Plot.png`` (a relative path,
    i.e. in the current working directory) and then shown.

    Args:
        metric: Display name of the metric; used in the title, y-axis label,
            legend labels and the output filename.
        train_metrics: Per-epoch training metric values.
        train_losses: Per-epoch training loss values.
        valid_metrics: Per-epoch validation metric values.
        valid_losses: Per-epoch validation loss values. Any sequence that
            supports ``len`` and integer indexing works (list, tuple, NumPy
            array). If it is empty, the early-stopping marker is omitted.
    """
    train_metric_color = "tab:orange"
    valid_metric_color = "tab:green"
    train_loss_color = "tab:purple"
    valid_loss_color = "tab:blue"

    figure, ax1 = plt.subplots()

    # Left y-axis: the evaluation metric.
    ax1.set_title(f"Training and Validation {metric} & Loss")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel(f"{metric}")
    ax1.plot(range(1, len(train_metrics) + 1), train_metrics,
             label=f"Training {metric}", marker='.', linestyle='-',
             color=train_metric_color)
    ax1.plot(range(1, len(valid_metrics) + 1), valid_metrics,
             label=f"Validation {metric}", marker='.', linestyle='-',
             color=valid_metric_color)
    ax1.legend(bbox_to_anchor=(0.3, -0.1), prop={'size': 8})

    # Right y-axis (shares the epoch x-axis): the losses.
    ax2 = ax1.twinx()
    ax2.set_ylabel("Cross Entropy")
    ax2.plot(range(1, len(train_losses) + 1), train_losses,
             label="Training Loss", marker='.', linestyle='-.',
             color=train_loss_color)
    ax2.plot(range(1, len(valid_losses) + 1), valid_losses,
             label="Validation Loss", marker='.', linestyle='-.',
             color=valid_loss_color)

    if len(valid_losses) > 0:
        # First (1-based) epoch with the lowest validation loss. Index-based
        # min keeps the first occurrence on ties and, unlike list.index(),
        # also works for tuples/NumPy arrays.
        best_epoch = min(range(len(valid_losses)), key=valid_losses.__getitem__) + 1
        ax2.axvline(best_epoch, linestyle='--', color='r', label="Early Stopping Checkpoint")
    ax2.legend(bbox_to_anchor=(1, -0.1), prop={'size': 7})

    figure.tight_layout()

    # Save BEFORE showing. A blocking plt.show() closes (and unregisters) the
    # figure once its window is dismissed. Saving first writes the file
    # immediately from an intact figure instead of after the window closes.
    date_time = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
    figure.savefig(f"{date_time}_{metric}_Plot.png")
    plt.show()