Object detecion 과정에서 하나의 객체에 대해 여러개의 Bbox가 중복해서 나타날 수 있다. 중복해서 나타는 Bbox중 가장 정확하다고 생각하는 하나의 Bbox를 선택하는 알고리즘이다.
특히 Default box를 사용하는 one-stage 알고리즘에서 주로 나타난다(하나의 셀에 여러개의 Bbox를 찾도록 정의되어서 그렇다).
NMS 알고리즘의 수도 코드는 다음과 간다.
While Bboxes:
Take out the largest probability box (확률은 높지만 IoU가 작은 bbox를 살리기 위해서)
Remove all other bboxes with IoU > threshold
(And we do this for each class)
먼저, 예측된 Bbox들 중 가장 높은 확률로 예측한 하나의 박스를 선택한다. 이 Bbox를 기준으로 나머지 Bbox와의 IoU를 계산한다. 이때, iou_theathold
보다 낮은 Bbox들을 제거한다.
정리하면,
1. 예측된 Bbox들 중 가장 높은 Probability의 Bbox를 선택 -> target_bbox
2. target_bbox
와 나머지 bbox들의 IoU를 계산 후 iou_threshold
보다 높은 bbox를 선택 -> result
3. 위 과정을 모든 class에 대해서 반복해서 수행
속도를 위해 개선해야 할 여지가 있지만 이는 나중에 해보도록 하겠다.
import torch
from iou import intersection_over_union
def non_maximum_suppression(bboxes, prob_threshold, iou_threshhold):
# prediction: [[cls, prob, cx, cy, w, h]] -> (N, 5)
bboxes = [bbox for bbox in bboxes if bbox[1] > prob_threshold]
# remove all bbox < probability threshold
bboxes = sorted(bboxes, key=lambda bbox:bbox[1], reverse=True)
# sort descending order by bbox probability
result = []
while bboxes:
target_bbox = bboxes.pop(0)
# select largest probaility bbox
bboxes = [
bbox for bbox in bboxes
if (bbox[0] != target_bbox[0]) or # ignore another class
(intersection_over_union( \
torch.tensor(target_bbox[2:]), \
torch.tensor(bbox[2:])) < iou_threshhold)
# ignore bbox that is lower that iou_threshold
]
result.append(target_bbox)
return result