Plot for Evaluation

d4r6j·2023년 9월 18일

ml modeling

목록 보기
4/5
post-thumbnail


def evaluate_plot(metric, train_metrics, train_losses, valid_metrics, valid_losses,
                  *, loss_name="Cross Entropy"):
    '''Plot per-epoch metric and loss curves on twin y-axes, save the figure, then show it.

    author : d4r6j

    The training/validation metric curves are drawn against the left y-axis and
    the training/validation loss curves against the right y-axis; epochs on the
    x-axis are numbered from 1. A dashed red vertical line labelled
    "Early Stopping Checkpoint" marks the epoch with the lowest validation loss
    (the first one, if several epochs tie).

    The figure is written to ``<YYYY-mm-dd.HH-MM-SS>_<metric>_Plot.png`` in the
    current working directory *before* ``plt.show()`` is called.

    Args:
        metric: Name of the metric; used in the title, left y-axis label,
            legend labels and the output file name.
        train_metrics: Per-epoch training metric values.
        train_losses: Per-epoch training loss values.
        valid_metrics: Per-epoch validation metric values.
        valid_losses: Per-epoch validation loss values. Any indexable sequence
            works (list, tuple, 1-D array). If it is empty, no checkpoint line
            is drawn.
        loss_name: Label for the right-hand (loss) y-axis.
    '''
    figure, ax1 = plt.subplots()

    ax1.set_title(f"Training and Validation {metric} & Loss")
    ax1.set_xlabel("epoch")

    # Left axis: metric curves (solid lines).
    ax1.set_ylabel(f"{metric}")
    _plot_curve(ax1, train_metrics, f"Training {metric}", '-', "tab:orange")
    _plot_curve(ax1, valid_metrics, f"Validation {metric}", '-', "tab:green")
    ax1.legend(bbox_to_anchor=(0.3, -0.1), prop={'size': 8})

    # Right axis: loss curves (dash-dot lines) sharing the same epoch x-axis.
    ax2 = ax1.twinx()
    ax2.set_ylabel(loss_name)
    _plot_curve(ax2, train_losses, "Training Loss", '-.', "tab:purple")
    _plot_curve(ax2, valid_losses, "Validation Loss", '-.', "tab:blue")

    best_epoch = _best_epoch(valid_losses)
    if best_epoch is not None:
        ax2.axvline(best_epoch, linestyle='--', color='r', label="Early Stopping Checkpoint")
    ax2.legend(bbox_to_anchor=(1, -0.1), prop={'size': 7})

    figure.tight_layout()

    # Save before showing: plt.show() may block until the window is closed and,
    # depending on the backend, leave the figure cleared, which yields a blank PNG
    # if savefig runs afterwards.
    date_time = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
    figure.savefig(f"{date_time}_{metric}_Plot.png")
    plt.show()


def _plot_curve(ax, values, label, linestyle, color):
    '''Draw one per-epoch curve on ``ax`` with dot markers, epochs numbered from 1.'''
    ax.plot(range(1, len(values) + 1),
            values,
            label=label,
            marker='.',
            linestyle=linestyle,
            color=color)


def _best_epoch(losses):
    '''Return the 1-based epoch of the first minimum in ``losses``, or None if it is empty.'''
    # len() rather than truthiness: `not array` is ambiguous for NumPy arrays.
    if len(losses) == 0:
        return None
    # min() keeps the first of equal keys, matching list.index(min(...)),
    # but works on any indexable sequence, not only lists.
    return min(range(len(losses)), key=lambda i: losses[i]) + 1

0개의 댓글