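In a classification task, a prediction can still be meaningful even when the top-ranked class is not the label. Let's implement top-5 accuracy, which counts a prediction as correct when the label is among the five highest-scoring classes!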

prediction = [40, 25, 10, 80, 60, 20, 5, 30, 10, 15]  # raw score for each of the 10 classes
label = 2  # index of the ground-truth class
k = 5

# Convert the scores into percentages of the total
sum_pred = sum(prediction)
print(sum_pred)
prob_prediction = [pred / sum_pred * 100 for pred in prediction]

# Pick the k largest values one at a time by rescanning the list each round
sorted_prob_pred = []
sorted_prob_pred_idx = []
for _ in range(k):
    M, M_idx = None, None
    for idx, prob in enumerate(prob_prediction):
        if idx in sorted_prob_pred_idx:
            continue  # already selected in an earlier round
        if M is None or prob > M:
            M, M_idx = prob, idx
    sorted_prob_pred.append(M)
    sorted_prob_pred_idx.append(M_idx)

print(sorted_prob_pred)
print(sorted_prob_pred_idx)

# Top-5 hit: the label is among the k highest-scoring classes
if label in sorted_prob_pred_idx:
    print("Good")
else:
    print("Bad")
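The same check can be written more compactly with the standard library. This is a minimal sketch, assuming `prediction` holds raw scores where higher means more confident; `top_k_indices` and `in_top_k` are hypothetical helper names. Dividing every score by the same positive total does not change their order, so the sketch ranks the raw scores directly.

```python
import heapq


def top_k_indices(scores, k=5):
    """Return the indices of the k largest scores, highest first."""
    # nlargest ranks the indices 0..n-1 by the score each one points to
    return heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])


def in_top_k(scores, label, k=5):
    """True if the label index is among the k highest-scoring classes."""
    return label in top_k_indices(scores, k)


prediction = [40, 25, 10, 80, 60, 20, 5, 30, 10, 15]
label = 2

print(top_k_indices(prediction))  # [3, 4, 0, 7, 1]
print("Good" if in_top_k(prediction, label) else "Bad")  # Bad
```

`heapq.nlargest` avoids the manual "skip already-picked indices" loop, and the `k` parameter makes top-1 or top-10 checks a one-argument change.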

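Result: the top-5 indices come out as [3, 4, 0, 7, 1]. The label 2 (score 10) is not among them, so the script prints Bad.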
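Top-5 accuracy is normally reported over a whole dataset as the fraction of samples whose label lands in the top five. Below is a minimal sketch with made-up scores and labels; `top_k_accuracy` is a hypothetical name, and the three samples are illustrative only, not real model output.

```python
def top_k_accuracy(batch_scores, labels, k=5):
    """Fraction of samples whose label is among the k highest scores."""
    if not labels:
        return 0.0
    hits = 0
    for scores, label in zip(batch_scores, labels):
        # Rank class indices by score, highest first (sorted is stable on ties)
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


batch_scores = [
    [40, 25, 10, 80, 60, 20, 5, 30, 10, 15],  # same example: label 2 misses the top 5
    [1, 9, 3, 7, 5, 2, 8, 4, 6, 0],           # label 4 ranks 5th -> hit
    [5, 4, 3, 2, 1, 0, 9, 8, 7, 6],           # label 0 ranks 5th -> hit
]
labels = [2, 4, 0]

print(top_k_accuracy(batch_scores, labels))       # 0.6666666666666666
print(top_k_accuracy(batch_scores, labels, k=1))  # 0.0
```

With `k=1` this reduces to ordinary accuracy, which is 0.0 here even though two of the three labels still make the top five.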
