Precision, Recall

마이클의 AI 연구소·2022년 2월 8일
0

개요

객체탐지 모델의 성능평가를 위해서는 객체가 존재하는 경우 잘 탐지하였는지, 존재하지 않을 때는 탐지하지 않는지를 확인해야 할 것입니다. 그러한 객체탐지 모델의 성능 평가에 사용되는 대표적인 평가지표인 정밀도와 재현율을 살펴봅시다.

먼저 그 지표를 도출하기 위한 추론결과를 다음과 같이 분류할 수 있습니다.

  • True Positive : True인 데이터를 True로 옳게 검출한 것
  • False Positive : False인 데이터를 True로 틀리게 검출한 것
  • False Negative : True인 데이터를 False로 틀리게 검출한 것
  • True Negative : False인 데이터를 False로 옳게 검출한 것
    (여기서 정답이란 모델이 검출하고자 하는 객체 클래스를 말함)

Precision 정밀도


정밀도는 모델이 True로 검출한 것 중에서 실제 True인 것의 비율입니다. 즉, 해당 객체가 존재한다고 말한 것 중에 얼마나 적중했는가를 말합니다.

정밀도만으로 정확도를 판단할 수 없습니다. 추론 회수자체가 적은 경우에는 정밀도가 높아질 수 있습니다. 예를들어, 50개의 True데이터와 50개의 False데이터에서 2개를 예측해서 2개 모두 True로 예측했다면 정밀도는 1이 됩니다. 또 다른 예로 True인 데이터를 False로 예측한 개수가 많더라도 정밀도에는 포함되지 않으므로 정확한 성능평가가 어려울 수 있습니다.

Recall 재현율


재현율은 실제 True인 데이터 중에 모델이 True라고 검출해낸 것의 비율입니다. 즉, 검출해내야 하는 타겟 클래스를 얼마나 다 검출했는가를 나타냅니다.

재현율만으로 정확도를 판단할 수 없습니다. 예를들어, 모든 데이터에 대해 무조건 True로 예측을 진행한다면 True인 데이터를 100% 검출해내는 것이 되므로 재현율이 1(100%)이 될 수 있습니다.

F1 Score

위와 같은 이유로 정밀도와 재현율을 개별적으로 사용하는 것은 좋지 않습니다. 두 지표를 적절히 혼합하여 사용하여 더욱 효과적인 지표를 만들어낼 필요가 있습니다. 그럴 때, F1 Score를 사용합니다.

F1 score는 정밀도와 재현율의 조화 평균입니다. 일반적인 평균을 구한다면 정밀도와 재현율을 더한 후, 2로 나누어줄텐데요. 위와 같이 정밀도와 재현율을 곱하는 것은 정밀도와 재현율이 둘중 하나가 0에 가까울 수록 F1 score도 동일하게 낮은 값을 갖도록 하기 위함입니다.

만일 재현율이 1이고 정밀도가 0.01인 모델이 있다고 했을 때, 일반 평균을 구한다면 (1+0.01)/2 = 0.505가 되지만, F1 score를 계산하면 2 (1 0.01) / (1+0.01) = 0.019가 되어 굉장히 낮은 값이 됩니다.

F1 score가 높다는 것은 두 수가 모두 고르게 높게 나오는 것을 의미하므로 저 종합적인 성능이 높은 것이라고 판단할 수 있습니다.

구현하기

def precision_recall(ious, gt_classes, pred_classes):
    """
    calculate precision and recall
    args:
    - ious [array]: NxM array of ious
    - gt_classes [array]: 1xN array of ground truth classes
    - pred_classes [array]: 1xM array of pred classes
    returns:
    - precision [float]
    - recall [float]
    """
    xs, ys = np.where(ious>0.5)

    # calculate true positive and true negative
    tps = 0
    fps = 0
    for x, y in zip(xs, ys):
        if gt_classes[x] == pred_classes[y]:
            tps += 1
        else:
            fps += 1

    matched_gt = len(np.unique(xs))
    fns = len(gt_classes) - matched_gt

    precision = tps / (tps+fps)
    recall = tps / (tps + fns)
    return precision, recall
profile
늘 성장을 꿈꾸는 자들을 위한 블로그입니다.

0개의 댓글