Detectron2 사용법

J. Hwang·2024년 10월 4일
0

Detectron2는 Facebook AI Research의 PyTorch를 기반으로 object detection, segmentation 등에 널리 쓰이는 딥러닝 라이브러리이다.
object detection은 classification 뿐만 아니라 물체가 이미지 내에서 어느 위치에 있는지까지 알아내야 하는 작업이다보니 직접 구현하려면 굉장히 복잡하고 어렵다. 이 때 MMDetection이나 Detectron2를 이용하면 configuration 설정을 하는 것만으로도 object detection을 수행할 수 있다.

Import

import os
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data import register_coco_instances

import detectron2.data.transforms as T

from detectron2.evaluation import COCOEvaluator
from detectron2.data import build_detection_test_loader, build_detection_train_loader

Configuration file 다루기

# configuration file 불러오기
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file('COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml'))

# dataset 설정
cfg.DATASETS.TRAIN = ("coco_trash_train")
cfg.DATASETS.TEST = ("coco_trash_val")

# 학습 설정
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url('COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml')
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.MAX_ITER = 3000
cfg.SOLVER.STEPS = (1000, 1500)
cfg.SOLVER.GAMMA = 0.05
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10
cfg.TEST.EVAL_PERIOD = 500

Augmentation mapper 정의

MMDetection과 달리 Detectron2는 라이브러리 내에서 augmentation 기능을 지원하지 않는다. 따라서 augmentation (및 데이터 전처리) 을 수행하는 함수를 따로 작성해서 적용해야 한다.

def MyMapper(dataset_dict):
    dataset_dict = copy.deepcopy(dataset_dict)
    image = utils.read_image(dataset_dict['file_name'], format='BGR')
    
    transform_list = [
        T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
        T.RandomBrightness(0.8, 1.8),
        T.RandomContrast(0.6, 1.3)
    ]
    
    image, transforms = T.apply_transform_gens(transform_list, image)
    
    dataset_dict['image'] = torch.as_tensor(image.transpose(2,0,1).astype('float32'))
    
    annos = [
        utils.transform_instance_annotations(obj, transforms, image.shape[:2])
        for obj in dataset_dict.pop('annotations')
        if obj.get('iscrowd', 0) == 0
    ]
    
    instances = utils.annotations_to_instances(annos, image.shape[:2])
    dataset_dict['instances'] = utils.filter_empty_instances(instances)
    
    return dataset_dict

Dataset 정의

# train dataset 정의
register_coco_instances('coco_trash_train', {}, '/home/data/train.json', '/home/data')
# validation dataset 정의
register_coco_instances('coco_trash_val', {}, '/home/data/val.json', '/home/data')

# 메타데이터 등록 (선택)
classes = ["General Trash", "Paper", "Paper pack", "Metal", "Glass", "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing"]

MetadataCatalog.get('coco_trash_train').set(thing_classes=classes)
MetadataCatalog.get('coco_trash_val').set(thing_classes=classes)

register_coco_instances의 인자

  • 첫번째 인자 : dataset의 이름을 지정한다.
  • 두번째 인자 : dataset과 관련된 메타데이터가 담긴 딕셔너리 (클래스 정보 등)
  • 세번째 인자 : COCO 형식의 annotation file
  • 네번째 인자 : 이미지가 저장된 디렉토리 경로

학습

Trainer 클래스를 정의한 후 학습시킨다.

class MyTrainer(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg, sampler=None):
        return build_detection_train_loader(
        cfg, mapper = MyMapper, sampler = sampler
        )
    
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            os.makedirs('./output_eval', exist_ok = True)
            output_folder = './output_eval'
            
        return COCOEvaluator(dataset_name, cfg, False, output_folder)
        
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

Custom Backbone 모델 등록하기

Detectron2 상에서 다양한 모델을 이용할 수 있는데, 원하는 모델이 없고 직접 구성해서 사용하고 싶다면 custom backbone 모델을 등록해서 사용할 수 있다.

from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec

@BACKBONE_REGISTRY.register()
class MyBackbone(Backbone):
	def __init__(self, cfg, input_shape):
    	super().__init__()
        pass
    def forward(self, image):
    	pass
    def output(self):
    	pass
        
cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = 'MyBackbone'
model = build_model(cfg)

References

https://github.com/facebookresearch/detectron2

profile
Let it code

0개의 댓글