YOLO fine-tuning

MachineDeepLearning

목록 보기
4/5
post-thumbnail
!pip install ultralytics

import torch
import os
from ultralytics import YOLO  # Using Ultralytics YOLO
from torch.utils.data import DataLoader
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

class YOLODataset(torch.utils.data.Dataset):
    """Dataset of images paired with YOLO-format label files.

    Each item is a tuple ``(image, bboxes, class_labels)``:

    * ``image`` -- tensor produced by the albumentations pipeline (resized to
      ``img_size`` x ``img_size``, normalized, converted by ``ToTensorV2``).
    * ``bboxes`` -- float tensor of shape ``(N, 4)`` holding normalized
      ``(x_center, y_center, width, height)`` boxes.
    * ``class_labels`` -- long tensor of shape ``(N,)``.

    The label file for an image is looked up by file stem
    (``abc.jpg`` -> ``abc.txt``) inside ``label_dir``. An image without a
    label file yields empty box/label tensors.

    Args:
        img_dir: Directory containing the image files.
        label_dir: Directory containing one ``.txt`` label file per image.
        img_size: Side length the image is resized to.
        augment: When true, random augmentations are applied before the
            resize/normalize step.
    """

    def __init__(self, img_dir, label_dir, img_size=640, augment=False):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_size = img_size
        self.augment = augment
        self.img_files = sorted(os.listdir(img_dir))
        # Derive label names from image names instead of listing label_dir:
        # pairing two independently sorted listings by index silently
        # mismatches every later pair as soon as one label file is missing.
        self.label_files = [self._label_name(name) for name in self.img_files]

        def finalize():
            # Fresh transform instances per pipeline; both pipelines must end
            # with the same resize/normalize so train and val samples match.
            return [
                A.Resize(img_size, img_size),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ]

        # bbox_params is required on every pipeline that receives `bboxes`.
        self.base_transform = A.Compose(
            finalize(),
            bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']),
        )

        self.augment_transform = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
                A.GaussianBlur(p=0.2),
                *finalize(),
            ],
            bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']),
        )

    @staticmethod
    def _label_name(img_name):
        """Return the label file name that belongs to image file ``img_name``."""
        return os.path.splitext(img_name)[0] + ".txt"

    @staticmethod
    def _read_labels(label_path):
        """Parse a YOLO label file into ``(bboxes, class_labels)`` lists.

        Returns two empty lists when the file does not exist. Blank lines are
        skipped.

        Raises:
            ValueError: If a non-blank line does not hold exactly five values.
        """
        bboxes = []
        class_labels = []
        try:
            with open(label_path, "r") as f:
                for line_no, line in enumerate(f, start=1):
                    fields = line.split()
                    if not fields:
                        continue
                    if len(fields) != 5:
                        raise ValueError(
                            f"{label_path}:{line_no}: expected 5 values "
                            f"(class x_center y_center width height), got {len(fields)}"
                        )
                    class_id, *box = map(float, fields)
                    bboxes.append(box)
                    class_labels.append(int(class_id))
        except FileNotFoundError:
            # No label file: treat the image as having no objects.
            return [], []
        return bboxes, class_labels

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        label_path = os.path.join(self.label_dir, self.label_files[idx])

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        bboxes, class_labels = self._read_labels(label_path)

        transform = self.augment_transform if self.augment else self.base_transform
        transformed = transform(image=img, bboxes=bboxes, class_labels=class_labels)

        img = transformed['image']
        # Use the transformed labels: a transform may drop boxes, and the
        # labels must stay aligned with the boxes that remain.
        if len(transformed['bboxes']) > 0:
            bboxes = torch.tensor(transformed['bboxes'], dtype=torch.float32)
            class_labels = torch.tensor(transformed['class_labels'], dtype=torch.long)
        else:
            bboxes = torch.zeros((0, 4))
            class_labels = torch.zeros(0, dtype=torch.long)

        return img, bboxes, class_labels

def yolo_collate_fn(batch):
    """Collate ``(image, bboxes, class_labels)`` samples into one batch.

    Images are stacked into a single tensor. Boxes and labels are returned as
    per-image lists, because each image can hold a different number of
    objects and such tensors cannot be stacked by the default collate.
    """
    images, bboxes, class_labels = zip(*batch)
    return torch.stack(images), list(bboxes), list(class_labels)


# Dataset Paths
train_dataset = YOLODataset("/content/drive/MyDrive/train/images", "/content/drive/MyDrive/train/labels", img_size=640, augment=True)
val_dataset = YOLODataset("/content/drive/MyDrive/valid/images", "/content/drive/MyDrive/valid/labels", img_size=640, augment=False)

# NOTE(review): these loaders are not referenced by the training call visible
# in this script, which is given a data.yaml path instead -- confirm whether
# they are still needed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=yolo_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=yolo_collate_fn)

# Load Pretrained YOLO Model
# Prefer the GPU when CUDA is available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained checkpoint name; change it to use a different model variant.
yolo_model = YOLO("yolov8s.pt")
yolo_model.to(device)

def train_model(model, data_path, num_epochs=10, imgsz=416, batch=16, device=None):
    """Fine-tune ``model`` on the dataset described by ``data_path``.

    Args:
        model: Object exposing ``train(data=..., epochs=..., imgsz=...,
            batch=..., device=...)``.
        data_path: Path to the dataset YAML passed to ``model.train``.
        num_epochs: Number of training epochs.
        imgsz: Training image size forwarded to ``model.train``.
        batch: Batch size forwarded to ``model.train``.
        device: Device forwarded to ``model.train``. When ``None``, "cuda" is
            used if CUDA is available and "cpu" otherwise, so the call no
            longer fails on machines without a GPU.

    Returns:
        The same ``model`` object, after training.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.train(data=data_path, epochs=num_epochs, imgsz=imgsz, batch=batch, device=device)
    return model

# Fine-tune the model
fine_tuned_model = train_model(
    model=yolo_model,
    data_path="/content/drive/MyDrive/data.yaml",
    num_epochs=10,
)

# Save fine-tuned model
fine_tuned_model.save("yolo_finetuned.pt")

Ground Truth

Prediction

profile
AI, Graphics, Medical Imaging

0개의 댓글