MMDetection (4)

Myeongsu Moon·2024년 12월 17일

제로베이스

목록 보기
39/95
post-thumbnail

Chapter 6. 직접 코딩으로 해보는 Object Detection 모델 학습

  • import
import copy
import datetime
import math
import os
import sys
import time
import warnings
from collections import defaultdict, deque

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from albumentations.pytorch import ToTensorV2
from PIL import Image
from pycocotools.coco import COCO
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, models
from torchvision import transforms as r
from torchvision.transforms import functional as FT
from torchvision.utils import draw_bounding_boxes
from tqdm import tqdm

%matplotlib inline
warnings.filterwarnings('ignore')
  • 모듈 동작 확인
# Show the installed torch / torchvision versions, one per line.
print(torch.__version__, torchvision.__version__, sep='\n')
  • 데이터셋 (아쿠아리움 데이터)
# Extract the dataset archive into /content/Aquarium.
# This is an IPython `!` shell escape, so it only runs inside Jupyter/Colab.
# '데이터 경로' ("data path") is a placeholder: replace it with the directory
# that actually contains the zip file.
!unzip '데이터 경로/Aquarium Combined.v2-raw-1024.coco.zip' -d '/content/Aquarium'
  • 변환함수
def get_transforms(train=False, size=600):
  """Build the albumentations pipeline used for training or evaluation.

  Both pipelines resize the image to ``size`` x ``size`` and finish with
  ``ToTensorV2``. Boxes are given in COCO ``[x, y, w, h]`` format; extra
  values appended after the four coordinates (e.g. a category id) are
  carried along with each box.

  Args:
    train: if True, insert random flips and brightness/colour augmentation
      between the resize and the tensor conversion.
    size: output height and width in pixels (defaults to the original 600).

  Returns:
    An ``A.Compose`` callable invoked as ``t(image=..., bboxes=...)``.
  """
  steps = [A.Resize(size, size)]
  if train:
    steps += [
        A.HorizontalFlip(p=0.3),
        A.VerticalFlip(p=0.3),
        A.RandomBrightnessContrast(p=0.1),
        A.ColorJitter(p=0.1),
    ]
  steps.append(ToTensorV2())
  return A.Compose(steps, bbox_params=A.BboxParams(format='coco'))
def collate_fn(batch):
  """Regroup a list of ``(image, target)`` samples into ``(images, targets)``.

  Samples are transposed into tuples rather than stacked, so per-image
  targets (which may hold different numbers of boxes) stay separate.
  """
  columns = zip(*batch)
  return tuple(column for column in columns)
  • 데이터셋 관련 함수
class AquariumDetection(datasets.VisionDataset):
  """COCO-format detection dataset read from ``<root>/<split>/``.

  Annotations are loaded from ``<root>/<split>/_annotations.coco.json`` and
  image files are resolved relative to the same directory. Images without any
  annotation are skipped.

  ``transforms`` is called as ``transforms(image=..., bboxes=...)`` and must
  return a dict with ``'image'`` (a CHW tensor) and ``'bboxes'``, as an
  albumentations ``Compose`` does.
  NOTE(review): passing ``transform``/``target_transform`` instead makes
  VisionDataset build its own ``transforms`` wrapper, which likely does not
  accept these keyword arguments. Confirm before relying on them.
  """

  def __init__(self, root, split='train', transform=None, target_transform=None, transforms=None):
    super().__init__(root, transforms, transform, target_transform)
    self.split = split
    self.coco = COCO(os.path.join(root, split, "_annotations.coco.json"))
    # Keep only images that have at least one annotation.
    self.ids = [img_id for img_id in sorted(self.coco.imgs.keys())
                if len(self._load_target(img_id)) > 0]

  def _load_image(self, img_id: int):
    """Read the image for ``img_id`` from disk and return it as an RGB array."""
    path = self.coco.loadImgs(img_id)[0]['file_name']
    full_path = os.path.join(self.root, self.split, path)
    image = cv2.imread(full_path)
    if image is None:
      # cv2.imread reports failure by returning None instead of raising.
      raise FileNotFoundError(f"Could not read image: {full_path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  def _load_target(self, img_id):
    """Return the list of COCO annotation dicts attached to ``img_id``."""
    return self.coco.loadAnns(self.coco.getAnnIds(img_id))

  def __getitem__(self, index):
    """Return ``(image, target)`` for the ``index``-th annotated image.

    ``image`` is a float tensor scaled by 1/255. ``target`` holds ``boxes``
    (``[xmin, ymin, xmax, ymax]``, float32, shape ``(N, 4)``), ``labels``,
    ``image_id``, ``area`` and ``iscrowd``. All entries are aligned with the
    boxes that survive the transform.
    """
    img_id = self.ids[index]
    image = self._load_image(img_id)
    target = self._load_target(img_id)

    # Append the category id and the annotation's index as extra box values,
    # so labels/iscrowd stay aligned even if the transform drops a box.
    boxes = [t['bbox'] + [t['category_id'], i] for i, t in enumerate(target)]

    if self.transforms is not None:
      transformed = self.transforms(image=image, bboxes=boxes)
      image = transformed['image']
      boxes = transformed['bboxes']
    else:
      # No transform supplied: turn the HWC array into a CHW tensor ourselves.
      image = torch.from_numpy(image).permute(2, 0, 1)

    kept = [target[int(box[5])] for box in boxes]
    # COCO [x, y, w, h] -> [xmin, ymin, xmax, ymax]
    xyxy = [[x, y, x + w, y + h] for x, y, w, h, *_ in boxes]
    # reshape keeps shape (0, 4) when no box survives, so column indexing works.
    boxes_t = torch.tensor(xyxy, dtype=torch.float32).reshape(-1, 4)

    targ = {
        'boxes': boxes_t,
        'labels': torch.tensor([int(box[4]) for box in boxes], dtype=torch.int64),
        'image_id': torch.tensor([t['image_id'] for t in kept], dtype=torch.int64),
        'area': (boxes_t[:, 3] - boxes_t[:, 1]) * (boxes_t[:, 2] - boxes_t[:, 0]),
        'iscrowd': torch.tensor([t['iscrowd'] for t in kept], dtype=torch.int64),
    }

    return image.div(255), targ

  def __len__(self):
    """Number of annotated images in this split."""
    return len(self.ids)
  • 데이터 전처리 및 확인
dataset_path = '/content/Aquarium'

# Load the training annotations once to read the category table.
coco = COCO(os.path.join(dataset_path, 'train', '_annotations.coco.json'))
categories = coco.cats
n_classes = len(categories)
categories  # notebook display only

# NOTE(review): `classes[label]` below assumes category ids are 0..n-1 in the
# same order as `coco.cats` -- confirm against the annotation file.
classes = [cat['name'] for cat in categories.values()]
classes  # notebook display only

# Visual sanity check: draw the (augmented) boxes of one training sample.
train_dataset = AquariumDetection(root=dataset_path, transforms=get_transforms(True))
sample = train_dataset[2]
# The sample image is float in [0, 1]; round before casting so k/255*255 does
# not truncate to k-1, and avoid torch.tensor(tensor), which copies and warns.
img_int = sample[0].mul(255).round().to(torch.uint8)
plt.imshow(draw_bounding_boxes(
    img_int, sample[1]['boxes'], [classes[i] for i in sample[1]['labels'].tolist()], width=4
).permute(1, 2, 0))

  • 모델 준비
# Faster R-CNN with a MobileNetV3-Large FPN backbone, starting from pretrained weights.
model = models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)

# Swap the box predictor for one sized to this dataset's classes. The
# attribute must be `box_predictor`: assigning any other name just adds an
# unused attribute and silently keeps the original pretrained head.
# NOTE(review): FastRCNNPredictor's num_classes includes background (label 0);
# confirm that `n_classes` (number of COCO categories) accounts for that.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(in_features, n_classes)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)

# Sanity check: one forward pass (model is still in training mode) on the CPU.
images, targets = next(iter(train_loader))
images = list(images)
targets = [dict(t) for t in targets]
output = model(images, targets)

# Prefer the GPU, but fall back to the CPU instead of failing without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-4)
def train_one_epoch(model, optimizer, loader, device, epoch, writer=None):
  """Run one training pass over ``loader`` and report the mean losses.

  Args:
    model: detection model which, in train mode, returns a dict of scalar loss
      tensors from ``model(images, targets)``. The dict must contain
      ``loss_classifier``, ``loss_box_reg``, ``loss_rpn_box_reg`` and
      ``loss_objectness``, which are printed and logged.
    optimizer: optimizer stepping the model's parameters.
    loader: iterable of ``(images, targets)`` batches, where ``images`` is a
      sequence of tensors and ``targets`` a sequence of dicts.
    device: device the model and every batch are moved to.
    epoch: epoch index, used in the log line and as the TensorBoard step.
    writer: object with ``add_scalar(tag, value, step)`` (e.g. a
      ``SummaryWriter``). Defaults to the module-level ``writer`` global, as
      before. TensorBoard logging is skipped if neither is available.

  Raises:
    SystemExit: with code 1 when a non-finite total loss is encountered.
  """
  if writer is None:
    writer = globals().get('writer')

  model.to(device)
  model.train()

  all_losses = []
  per_key_losses = {}

  for images, targets in tqdm(loader):
    images = [image.to(device) for image in images]
    # as_tensor avoids the copy + warning that torch.tensor(tensor) triggers.
    targets = [{k: torch.as_tensor(v).to(device) for k, v in t.items()} for t in targets]

    loss_dict = model(images, targets)
    losses = sum(loss for loss in loss_dict.values())
    loss_value = losses.item()

    all_losses.append(loss_value)
    for k, v in loss_dict.items():
      per_key_losses.setdefault(k, []).append(v.item())

    if not math.isfinite(loss_value):
      print(f"Loss is {loss_value}, stopping training")
      print(loss_dict)
      raise SystemExit(1)  # same effect as sys.exit(1), without needing `sys`

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()

  if not all_losses:
    # Nothing to average; avoid NaN means / missing-key errors.
    print(f"Epoch {epoch}: loader produced no batches, nothing to log")
    return

  lr = optimizer.param_groups[0]['lr']
  mean_total = sum(all_losses) / len(all_losses)
  means = {k: sum(v) / len(v) for k, v in per_key_losses.items()}

  # One placeholder per argument, so each value appears under its own label.
  print("Epoch {}, lr: {:.6f}, loss: {:.6f}, loss_classifier: {:.6f}, loss_box: {:.6f}, loss_rpn_box: {:.6f}, loss_object: {:.6f}".format(
      epoch, lr, mean_total,
      means['loss_classifier'],
      means['loss_box_reg'],
      means['loss_rpn_box_reg'],
      means['loss_objectness'],
  ))

  if writer is not None:
    writer.add_scalar('lr', lr, epoch)
    writer.add_scalar('loss', mean_total, epoch)
    for key in ('loss_classifier', 'loss_box_reg', 'loss_rpn_box_reg', 'loss_objectness'):
      writer.add_scalar(key, means[key], epoch)
num_epochs = 10

# train_one_epoch logs to this module-level `writer`.
writer = SummaryWriter()

try:
  for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_loader, device, epoch)
finally:
  # Flush and close even if training is interrupted or raises, so the
  # TensorBoard event file is not left half-written.
  writer.flush()
  writer.close()
  • 학습 결과 이미지로 확인
model.eval()
torch.cuda.empty_cache()

test_dataset = AquariumDetection(root=dataset_path, split='test', transforms=get_transforms(False))
img, _ = test_dataset[46]
# img is float in [0, 1]: scale by 255 (scaling by 256 maps 1.0 to 256,
# which does not fit in uint8) and round before casting.
img_int = img.mul(255).round().to(torch.uint8)

with torch.no_grad():
  prediction = model([img.to(device)])
# Bring predictions to the CPU to match the CPU image tensor used for drawing.
pred = {k: v.cpu() for k, v in prediction[0].items()}
print(pred)

fig = plt.figure(figsize=(14, 10))

# Compute the score mask once and reuse it for labels, scores and boxes.
score_threshold = 0.8
keep = pred['scores'] > score_threshold

filtered_label = [classes[i] for i in pred['labels'][keep].tolist()]
filtered_score = pred['scores'][keep].tolist()

label_text = [f"{lbl} {scr:.2f}" for lbl, scr in zip(filtered_label, filtered_score)]

plt.imshow(draw_bounding_boxes(img_int,
                               pred['boxes'][keep],
                               label_text, width=4
                               ).permute(1, 2, 0))

이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.

0개의 댓글