[OD] Non-max Suppression (NMS)

Sinaenjuni·2024년 1월 10일
0

Object detection

목록 보기
2/4

Non-max Suppression이란?

Object detecion 과정에서 하나의 객체에 대해 여러개의 Bbox가 중복해서 나타날 수 있다. 중복해서 나타는 Bbox중 가장 정확하다고 생각하는 하나의 Bbox를 선택하는 알고리즘이다.

특히 Default box를 사용하는 one-stage 알고리즘에서 주로 나타난다(하나의 셀에 여러개의 Bbox를 찾도록 정의되어서 그렇다).

Pseudo-code of NMS

NMS 알고리즘의 수도 코드는 다음과 간다.

While Bboxes:
	Take out the largest probability box (확률은 높지만 IoU가 작은 bbox를 살리기 위해서)
	Remove all other bboxes with IoU > threshold
(And we do this for each class)

먼저, 예측된 Bbox들 중 가장 높은 확률로 예측한 하나의 박스를 선택한다. 이 Bbox를 기준으로 나머지 Bbox와의 IoU를 계산한다. 이때, iou_theathold보다 낮은 Bbox들을 제거한다.

정리하면,
1. 예측된 Bbox들 중 가장 높은 Probability의 Bbox를 선택 -> target_bbox
2. target_bbox와 나머지 bbox들의 IoU를 계산 후 iou_threshold보다 높은 bbox를 선택 -> result
3. 위 과정을 모든 class에 대해서 반복해서 수행

Implemetation with Pytorch

속도를 위해 개선해야 할 여지가 있지만 이는 나중에 해보도록 하겠다.

import torch
from iou import intersection_over_union

def non_maximum_suppression(bboxes, prob_threshold, iou_threshhold):
    # prediction: [[cls, prob, cx, cy, w, h]] -> (N, 5)

    bboxes = [bbox for bbox in bboxes if bbox[1] > prob_threshold]
    # remove all bbox < probability threshold
    bboxes = sorted(bboxes, key=lambda bbox:bbox[1], reverse=True)
    # sort descending order by bbox probability

    result = []
    while bboxes:
        target_bbox = bboxes.pop(0) 
        # select largest probaility bbox

        bboxes = [
            bbox for bbox in bboxes
            if (bbox[0] != target_bbox[0]) or # ignore another class
                (intersection_over_union( \
                	torch.tensor(target_bbox[2:]), \
                    torch.tensor(bbox[2:])) < iou_threshhold) 
                # ignore bbox that is lower that iou_threshold
        ]
        result.append(target_bbox)
    
    return result

0개의 댓글