MMDetection (3)

Myeongsu Moon·2024년 12월 15일
0

제로베이스

목록 보기
38/95
post-thumbnail

Chapter5 다중 class Object 모델 학습

데이터 준비

  • 사용할 데이터셋
    아쿠아리움 데이터

  • 아쿠아리움 데이터 다운받기

  • 로컬 PC에 다운로드 후 구글드라이브에 업로드

  • 코랩 환경에서 데이터 가져오기

!unzip '.다운로드 파일 경로/Aquarium Combined.v2-raw-1024.coco-mmdetection.zip' -d './Aquarium'
  • 예제 데이터 확인
import mmcv
import matplotlib.pyplot as plt

img = mmcv.imread('/content/Aquarium/test/IMG_2289_jpeg_jpg.rf.fe2a7a149e7b11f2313f5a7b30386e85.jpg')
plt.figure(figsize=(15,10))
plt.imshow(mmcv.bgr2rgb(img))
plt.show()

데이터 학습

  • 8개 class가 데이터에 들어있음
!mim download mmdet --config faster-rcnn_r50_fpn_1x_coco --dest ./Aquarium_project
  • config 파일 가져오기
from mmengine import Config

cfg = Config.fromfile('./Aquarium_project/faster-rcnn_r50_fpn_1x_coco.py')
  • config 파일 수정
cfg.metainfo = {
    'classes' : ('creatures','fish','jellyfish','penguin','puffin','shark','starfish','stingray',),
    'palette' : [
        (220, 20, 60), (255, 179, 0), (219, 242, 255), (40, 27, 134), (224, 186, 183), (150, 69, 50), (255, 171, 194), (255, 102, 204),
    ]
}
cfg.data_root = './Aquarium'

cfg.train_dataloader.dataset.ann_file = 'train/_annotations.coco.json'
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.train_dataloader.dataset.data_prefix.img = 'train'
cfg.train_dataloader.dataset.metainfo = cfg.metainfo

cfg.val_dataloader.dataset.ann_file = 'valid/_annotations.coco.json'
cfg.val_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_prefix.img = 'valid'
cfg.val_dataloader.dataset.metainfo = cfg.metainfo

cfg.test_dataloader.dataset.ann_file = 'test/_annotations.coco.json'
cfg.test_dataloader.dataset.data_root = cfg.data_root
cfg.test_dataloader.dataset.data_prefix.img = 'test'
cfg.test_dataloader.dataset.metainfo = cfg.metainfo
cfg.val_evaluator.ann_file = './Aquarium/valid/_annotations.coco.json'
cfg.test_evaluator.ann_file = './Aquarium/test/_annotations.coco.json'

cfg.model.roi_head.bbox_head.num_classes = 8

cfg.load_from = './Aquarium_project/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

cfg.work_dir = './Aquarium_project/work_dir'
cfg.train_cfg.val_interval = 3
cfg.default_hooks.checkpoint.interval = 3

cfg.optim_wrapper.optimizer.lr = 0.02 / 8
cfg.default_hooks.logger.interval = 10

cfg.seed = 0

cfg.visualizer.vis_backends.append({"type":'TensorboardVisBackend'})
  • config 파일 저장
with open('./Aquarium_project/Aquarium_config.py', 'w') as f:
  f.write(cfg.pretty_text)
  • 학습
!python tools/train.py ./Aquarium_project/Aquarium_config.py
  • tensorboard로 학습결과 확인
import tensorboard
%load_ext tensorboard
%tensorboard --logdir './Aquarium_project/work_dir' --port=8888
  • 새로운 이미지로 모델 테스트
import mmcv
from mmdet.apis import init_detector, inference_detector

img = mmcv.imread('./Aquarium/test/IMG_2632_jpeg_jpg.rf.f44037edca490b16cbf06427e28ea946.jpg', channel_order='rgb')

checkpoint_file = './Aquarium_project/work_dir/epoch_12.pth'
model = init_detector(cfg, checkpoint_file, device='cuda:0')
new_result = inference_detector(model, img)

print(new_result)
from mmdet.registry import VISUALIZERS

visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta
visualizer.add_datasample(
    'new_result',
    img,
    data_sample=new_result,
    draw_gt = False,
    wait_time = 0,
    out_file=None,
    pred_score_thr=0.5
)

visualizer.show()

이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다

0개의 댓글