여러 클래스의 객체를 탐지할 수 있도록 설계된 모델
ex) 자율주행 차량이 자동차, 사람, 자전거, 신호등을 동시에 탐지
다양한 객체를 처리하기 위해 모델 구조가 더 복잡하기에 단일 클래스 OD에 비해 속도가 느리거나 정확도가 낮을 수 있음
다양한 클래스 간 균형 잡힌 학습이 성능평가의 관건임
| 특성 | 단일 클래스 OD | 멀티 클래스 OD |
|---|---|---|
| 탐지 대상 | 한 가지 객체 클래스 | 여러 객체 클래스 |
| 모델 구조 | 단순, 경량화 | 복잡, 다양한 클래스 처리 가능 |
| 속도 | 빠름 | 비교적 느림 |
| 정확도 | 특정 클래스에 최적화, 높은 정확도 | 클래스 간 균형이 중요 |
unzip -qq aquarium.zip -d ./aquarium
# 샘플 데이터 출력
import os
import os.path as op
import mmcv
import matplotlib.pyplot as plt
MMDET_PATH = "/content/mmdetection"
AQUA_PATH = op.join(MMDET_PATH, "aquarium")
fname = os.listdir(op.join(AQUA_PATH, "train"))[100]
samp_img = mmcv.imread(op.join(AQUA_PATH, "train", fname))
plt.figure(figsize=(7,7))
plt.imshow(mmcv.bgr2rgb(samp_img))
plt.axis(False)
plt.show()
mim download mmdet --config faster-rcnn_r50_fpn_1x_coco --dest ./aquarium
# config 파일 로드
from mmengine.config import Config
cfg = Config.fromfile(op.join(AQUA_PATH, "faster-rcnn_r50_fpn_1x_coco.py"))
# config 파일 수정
cfg.metainfo = {
"classes": ("creatures", "fish", "jellyfish", "penguin", "puffin", "shark", "starfish", "stingray"),
"palette": [
(220, 20, 60), (255, 179, 0), (219, 242, 255), (40, 27, 134),
(224, 186, 183), (150, 69, 50), (255, 171, 194), (255, 102, 204)
]
}
cfg.data_root = "./aquarium"
cfg.train_dataloader.dataset.ann_file = "train/_annotations.coco.json"
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.train_dataloader.dataset.data_prefix.img = "train/"
cfg.train_dataloader.dataset.metainfo = cfg.metainfo
cfg.val_dataloader.dataset.ann_file = "valid/_annotations.coco.json"
cfg.val_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_prefix.img = "valid/"
cfg.val_dataloader.dataset.metainfo = cfg.metainfo
cfg.test_dataloader.dataset.ann_file = "test/_annotations.coco.json"
cfg.test_dataloader.dataset.data_root = cfg.data_root
cfg.test_dataloader.dataset.data_prefix.img = "test/"
cfg.test_dataloader.dataset.metainfo = cfg.metainfo
cfg.val_evaluator.ann_file = "./aquarium/valid/_annotations.coco.json"
cfg.test_evaluator.ann_file = "./aquarium/test/_annotations.coco.json"
cfg.model.roi_head.bbox_head.num_classes = 8
cfg.load_from = "./aquarium/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
cfg.work_dir = "./aquarium/work_dir"
cfg.train_cfg.val_interval = 3
cfg.default_hooks.checkpoint.interval = 3
cfg.optim_wrapper.optimizer.lr = 0.02/8 # 1개 GPU 사용
cfg.default_hooks.logger.interval = 10
cfg.seed = 0
cfg.visualizer.vis_backends.append({"type": "TensorboardVisBackend"})
with open("./aquarium/aquarium_config.py", "w") as f:
f.write(cfg.pretty_text)
python tools/train.py ./aquarium/aquarium_config.py
import tensorboard
%load_ext tensorboard
%tensorboard --logdir ./aquarium/work_dir
import mmcv
from mmdet.apis import init_detector, inference_detector
fname = os.listdir("./aquarium/test")[0]
img = mmcv.imread(os.path.join("./aquarium/test", fname), channel_order="rgb")
checkpoint_file = "./aquarium/work_dir/epoch_12.pth"
model = init_detector(cfg, checkpoint_file, device="cuda:0")
res = inference_detector(model, img)
print(res)
from mmdet.registry import VISUALIZERS
vis = VISUALIZERS.build(model.cfg.visualizer)
vis.dataset_meta = model.dataset_meta
vis.add_datasample(
name="aquarium",
image=img,
data_sample=res,
draw_gt=False,
wait_time=0,
out_file=None,
pred_score_thr=0.5,
)
vis.show()
*이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.