def get_ann(img_id, coco):
    """Look up the annotation ids attached to one image.

    Args:
        img_id: Id of the image whose annotations are wanted.
        coco: COCO API object queried via ``getAnnIds``.

    Returns:
        Whatever ``coco.getAnnIds(imgIds=img_id)`` returns (the annotation ids).
    """
    ann_ids = coco.getAnnIds(imgIds=img_id)
    return ann_ids
def get_img(img_id, coco, base_path='/opt/ml/detection/dataset'):
    """Load the image file for ``img_id`` with OpenCV.

    Args:
        img_id: Id of the image to load.
        coco: COCO API object; ``loadImgs(img_id)[0]['file_name']`` gives the
            image path relative to ``base_path``.
        base_path: Dataset root that ``file_name`` is resolved against.
            Defaults to the original hard-coded location.

    Returns:
        The image array as returned by ``cv2.imread`` (OpenCV BGR channel order).

    Raises:
        FileNotFoundError: If OpenCV could not read the file.
    """
    file_name = coco.loadImgs(img_id)[0]['file_name']
    img_path = os.path.join(base_path, file_name)
    image = cv2.imread(img_path)
    # cv2.imread does not raise on a missing or unreadable file; it returns None,
    # which would only fail later with a confusing error. Fail here instead.
    if image is None:
        raise FileNotFoundError(f"Could not read image {img_path!r} (img_id={img_id})")
    return image
def draw_bboxes(img_id, coco, ax):
    """Draw every annotation box and category label of an image onto ``ax``.

    Args:
        img_id: Id of the image to render.
        coco: COCO API object supplying the image, annotations and categories.
        ax: Matplotlib axes that the annotated image is shown on.
    """
    # get_img returns OpenCV's BGR order, but matplotlib's imshow expects RGB.
    # Converting before drawing also means the colour tuples below are read as
    # RGB (green boxes, red labels).
    image = cv2.cvtColor(get_img(img_id, coco), cv2.COLOR_BGR2RGB)
    # Load all annotation records in one call instead of twice per annotation.
    for ann in coco.loadAnns(get_ann(img_id, coco)):
        x, y, w, h = map(int, ann['bbox'])
        label = coco.loadCats(ann['category_id'])[0]['name']
        cv2.rectangle(image, (x, y), (x + w, y + h), (36, 255, 12), 3)
        # Boxes that touch the top edge would put the label at a negative y
        # (off-canvas). Clamp it to roughly one text-line height from the top.
        text_y = max(y - 10, 20)
        cv2.putText(image, label, (x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 4)
    ax.imshow(image)
# Plot a grid of randomly chosen training images with their annotations drawn.
n_rows, n_cols = 4, 4
# squeeze=False keeps `ax` 2-D, so ax[row][col] also works for 1-row/1-col grids.
fig, ax = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, figsize=(20, 20), squeeze=False)
# Sample distinct ids from the ids the COCO index actually holds, so every image
# (including the first and last) can be picked and no id appears twice.
img_ids = coco.getImgIds()
n_samples = min(n_rows * n_cols, len(img_ids))
# Convert numpy integers to plain int. pycocotools' loadImgs type-checks for
# `int`, so a numpy integer is not accepted.
index = [int(i) for i in np.random.choice(img_ids, size=n_samples, replace=False)]
for i, idx in enumerate(index):
    # Row-major placement; this is valid for any n_rows x n_cols grid.
    row, col = divmod(i, n_cols)
    draw_bboxes(idx, coco, ax[row][col])
    ax[row][col].set_title(idx)
plt.tight_layout()
plt.show()