[AI] Object Detection 실습

Bora Kwon·2022년 6월 29일
0

custom_data.py : 학습 코드 및 데이터 로드 코드

import torch.nn as nn
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import cv2
import numpy as np
import torch
import torchvision
# import albumentations as albumentations
import time

"""프로세스 상태정보 보여 주는것과 동일"""
from tqdm import tqdm
from math import gamma
from bs4 import BeautifulSoup
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
# from albumentations.pytorch import ToTensorV2


def generate_box(obj):
    """Extract a bounding box from one parsed <object> annotation node.

    The node is expected to contain children like:
        <xmin>79</xmin> <ymin>105</ymin> <xmax>109</xmax> <ymax>142</ymax>

    Returns:
        list[float]: [xmin, ymin, xmax, ymax] (pascal_voc corner format).
    """
    corners = ("xmin", "ymin", "xmax", "ymax")
    return [float(obj.find(tag).text) for tag in corners]


def generate_label(obj):
    """Map an <object> node's <name> text to a class id.

    Class mapping: with_mask -> 1, mask_weared_incorrect -> 2,
    anything else (e.g. without_mask) -> 0.
    """
    name = obj.find("name").text
    if name == 'with_mask':
        return 1
    if name == 'mask_weared_incorrect':
        return 2
    return 0


def generate_target(file):
    """Parse a Pascal-VOC style annotation XML file into a detection target.

    Args:
        file: path to the annotation .xml file.

    Returns:
        dict with
            "boxes":  float32 tensor of shape (N, 4), [xmin, ymin, xmax, ymax],
            "labels": int64 tensor of shape (N,).
    """
    with open(file) as f:
        data = f.read()
    # NOTE: "html.parser" lowercases tag names; that is harmless here
    # because the VOC tags (<object>, <xmin>, ...) are already lowercase.
    soup = BeautifulSoup(data, "html.parser")
    objects = soup.find_all("object")

    boxes = [generate_box(obj) for obj in objects]
    labels = [generate_label(obj) for obj in objects]

    # reshape(-1, 4) keeps the (0, 4) shape when a file has no objects;
    # a plain as_tensor([]) would yield shape (0,) and break detection losses.
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    labels = torch.as_tensor(labels, dtype=torch.int64)

    return {"boxes": boxes, "labels": labels}


def plot_image_from_output(img, annotation):
    """Convert a tensor image and detection annotation into plottable form.

    Args:
        img: image tensor of shape (C, H, W); moved to CPU here.
        annotation: dict with "boxes" (N, 4) rows [xmin, ymin, xmax, ymax]
            and "labels" (N,) class ids.

    Returns:
        (img, rects): the image in (H, W, C) layout for imshow, and a list
        of matplotlib Rectangle patches color-coded by label:
        0 -> red, 1 -> green, anything else -> blue.
    """
    # tensor image -> (H, W, C) layout expected by imshow
    img = img.cpu().permute(1, 2, 0)

    # one edgecolor per class id; unknown ids fall back to blue
    edgecolors = {0: 'r', 1: 'g'}

    rects = []
    for box, label in zip(annotation["boxes"], annotation['labels']):
        xmin, ymin, xmax, ymax = box
        rects.append(patches.Rectangle(
            (xmin, ymin), (xmax - xmin), (ymax - ymin),
            linewidth=1, edgecolor=edgecolors.get(int(label), 'b'),
            facecolor='none'
        ))

    return img, rects


class MaskDataset(Dataset):
    """Face-mask detection dataset (images + Pascal-VOC XML annotations).

    Images are listed from `path`; each image `name.<ext>` is paired with
    an annotation `name.xml` read from ./test_annotations/ when the image
    path contains "test", otherwise from ./annotations/.
    """

    def __init__(self, path, transform=None):
        self.path = path
        # sorted so index -> file mapping is deterministic across runs
        self.img = list(sorted(os.listdir(self.path)))
        self.transform = transform

    def __len__(self):
        return len(self.img)

    def __getitem__(self, index):
        file_image = self.img[index]

        # Derive the annotation name from the image name; splitext handles
        # any extension length (the original [:-3] slice broke on ".jpeg").
        file_label = os.path.splitext(file_image)[0] + '.xml'
        img_path = os.path.join(self.path, file_image)

        if 'test' in self.path:
            label_path = os.path.join("./test_annotations/", file_label)
        else:
            label_path = os.path.join("./annotations/", file_label)

        # cv2 loads BGR; convert to the RGB ordering the model expects
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        target = generate_target(label_path)

        if self.transform:
            # transform is expected to follow the albumentations calling
            # convention (image=, bboxes=, labels=) — see commented config
            transformed = self.transform(
                image=image, bboxes=target['boxes'], labels=target['labels'])
            image = transformed['image']
        else:
            # no transform: just convert the HWC uint8 array to a CHW tensor
            image = torchvision.transforms.ToTensor()(image)

        return image, target


def collate_fn(batch):
    """Collate a detection batch without stacking.

    Detection targets vary in size per image, so instead of stacking into
    a single tensor we transpose the list of (image, target) pairs into
    (tuple_of_images, tuple_of_targets).
    """
    transposed = zip(*batch)
    return tuple(transposed)


# bbox_transform = albumentations.Compose([
#     albumentations.HorizontalFlip(),
#     albumentations.Rotate(p=0.8),
#     ToTensorV2()
# ], bbox_params=albumentations.BboxParams(
#     format='pascal_voc', label_fields=['labels']
# ))
# bbox_transform_test = albumentations.Compose([
#     ToTensorV2()
# ], bbox_params=albumentations.BboxParams(
#     format='pascal_voc', label_fields=['labels']
# ))


# Train/test datasets; MaskDataset resolves annotation paths based on
# whether the image directory path contains "test".
train_dataset = MaskDataset("./images/")
test_dataset = MaskDataset("./test_images/")

# NOTE(review): train_loader has no shuffle=True, so batches arrive in
# sorted-filename order every epoch — confirm that is intended for training.
train_loader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=4, collate_fn=collate_fn)


"""모델 호출"""
retina = torchvision.models.detection.retinanet_resnet50_fpn(
    num_classes=3, pretrained=False, pretrained_backbone=True
)

device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

num_epochs = 30
retina.to(device)

"""gradinet calculation 이 필요한 params 만 추출"""
params = [p for p in retina.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params, lr=0.0025, momentum=0.9, weight_decay=0.0005)

lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=10, gamma=0.1)

len_dataloader = len(train_loader)

"""train loop"""
for epoch in range(num_epochs):
    start = time.time()
    retina.train()

    i = 0
    epoch_loss = 0
    for index, (images, targets) in enumerate(train_loader):
        images = list(image.to(device) for image in images)

        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = retina(images, targets)
        optimizer.zero_grad()
        losses = sum(loss for loss in outputs.values())

        i += 1

        losses.backward()
        optimizer.step()

        if index % 10 == 0:
            print("loss >>", losses.item(), "epoch >> ", epoch, "index >> ", index,
                  f"time : {time.time() - start}")

            torch.save(retina.state_dict(),
                       f"./retina_{num_epochs}_{epoch}.pt")

    if lr_scheduler is not None:
        lr_scheduler.step()
torch.save(retina.state_dict(), f"./retina_last.pt")

data_split.py : 학습데이터 train/test 데이터로 나눔

import os
import random
import numpy as np
import shutil


"""폴더구성"""
os.makedirs("./test_images", exist_ok=True)
os.makedirs("./test_annotations", exist_ok=True)

print(len(os.listdir("./annotations")))
print(len(os.listdir("./images")))

random.seed(7777)
idx = random.sample(range(853), 170)

for img in np.array(sorted(os.listdir("./images")))[idx]:
    print("img info ", img)
    shutil.move("./images/" + img, "./test_images/"+img)

for anno in np.array(sorted(os.listdir("./annotations")))[idx]:
    print("annotation info ", anno)
    shutil.move("./annotations/" + anno, "./test_annotations/" + anno)

print("info file size \n")
print(len(os.listdir("./images")))
print(len(os.listdir("./annotations")))
print(len(os.listdir("./test_images")))
print(len(os.listdir("./test_annotations")))

data_call.py : 학습에 필요한 데이터를 다운로드

from zipfile import ZipFile
import gdown
import argparse

"""pip install gdown"""
"""사용법 : python3 data_call.py --data FaceMaskDetection"""

# Maps a dataset key (the --data CLI argument) to the local zip filename
# the download is written to.
file_destinations = {
    'FaceMaskDetection': 'Face Mask Detection.zip', }
# Maps the same dataset key to its Google Drive file id.
file_id_dic = {
    'FaceMaskDetection': '1pJtohTc9NGNRzHj5IsySR39JIRPfkgD3'
}


def download_file_from_google_drive(id_, destination):
    """Download a publicly shared Google Drive file via gdown.

    Args:
        id_: the Drive file id (the `id=` query parameter of the share URL).
        destination: local path the downloaded file is written to.
    """
    output = destination
    gdown.download(f"https://drive.google.com/uc?id={id_}", output, quiet=True)
    print(f"{output} download complete")


# CLI: python3 data_call.py --data FaceMaskDetection
parser = argparse.ArgumentParser(
    description='data loader ... '
)
parser.add_argument('--data', type=str, help='key for selecting data..!!')

args = parser.parse_args()


download_file_from_google_drive(
    id_=file_id_dic[args.data], destination=file_destinations[args.data]
)

# Extract the downloaded archive into the current directory.
test_file_name = "./Face Mask Detection.zip"

# `zf` instead of `zip`: the original name shadowed the builtin zip().
with ZipFile(test_file_name, 'r') as zf:
    zf.printdir()
    zf.extractall()

# download_file_from_google_drive(
#     id=file_id_dic["FaceMaskDetection"],
#     destination=file_destinations["FaceMaskDetection"]
# )
profile
Software Developer

0개의 댓글