#[IQA] 3. 자가 지도 학습(self-supervised learning))_self_supervised_learning_unet

degull·2024년 9월 4일
0

모델이 이미지 데이터에서 직접 특징을 학습하도록 하는 방법
→ 이를 통해 모델은 이미지의 장면 정보와 왜곡 유형을 인식하게 되고, 이러한 정보는 품질 점수 예측에 유용하게 사용되도록 함


Autoencoder 모델 생성

이미지 복원을 위한 Vision Transformer (ViT) 기반의 U-Net Autoencoder 모델을 학습

목표 : 손상된 이미지를 입력으로 받아 이미지를 복원하는 능력을 학습
사용 이유 : 이미지의 중요한 특징을 더 잘 학습할 수 있도록 돕기 때문



import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import random
from transformers import ViTModel
from pytorch_msssim import SSIM  # SSIM 손실 함수 사용

# GPU 설정
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 시드 고정
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# 데이터셋 준비
class IQADataset(Dataset):
    def __init__(self, txt_file, img_dir, transform=None):
        self.img_labels = pd.read_csv(txt_file, sep="\t", header=None)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        return image

# 이미지 전처리 및 데이터 증대 정의
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 데이터셋 경로
img_dir = 'C:/Users/IIPL02/Desktop/LIQE_nonIndex/IQA_Database/kadid10k'
txt_file = 'C:/Users/IIPL02/Desktop/LIQE_nonIndex/IQA_Database/kadid10k/splits2/kadid10k_val_clip.txt'

# 데이터셋 로드
dataset = IQADataset(txt_file=txt_file, img_dir=img_dir, transform=transform)
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=seed)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

# ViT 기반 인코더와 U-Net 디코더를 결합한 Autoencoder 모델 정의
class ViTUAutoencoder(nn.Module):
    def __init__(self):
        super(ViTUAutoencoder, self).__init__()
        # ViT 모델 로드 (인코더 역할)
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # U-Net 구조의 디코더
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(768, 512, kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            nn.Conv2d(64, 3, kernel_size=1),  # RGB 채널 복원
            nn.Sigmoid()  # [0, 1] 범위로 값을 제한
        )

    def forward(self, x):
        # ViT 인코더로 특징 추출 (batch_size, 196, 768)
        vit_outputs = self.vit(pixel_values=x)
        encoded = vit_outputs.last_hidden_state[:, 1:, :]  # CLS 토큰 제외
        encoded = encoded.permute(0, 2, 1).contiguous()  # (batch_size, 768, 196)
        encoded = encoded.view(encoded.size(0), 768, 14, 14)  # (batch_size, 768, 14, 14)로 reshape

        # U-Net 디코더 통과
        decoded = self.decoder(encoded)
        return decoded

model = ViTUAutoencoder().to(device)

# 손실 함수 및 옵티마이저 정의
criterion_mse = nn.MSELoss()  # 복원 손실 함수
criterion_ssim = SSIM(data_range=1.0, size_average=True, channel=3)
optimizer = optim.AdamW(model.parameters(), lr=0.00005)  # 학습률 조정
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

# 학습 루프 정의
num_epochs = 100
best_val_loss = float('inf')
early_stop_patience = 10
early_stop_counter = 0

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images in train_loader:
        images = images.to(device)
        
        # Forward Pass
        outputs = model(images)
        
        # 복합 손실 계산
        loss_mse = criterion_mse(outputs, images)
        loss_ssim = 1 - criterion_ssim(outputs, images)  # SSIM을 사용한 추가 손실
        loss = 0.7 * loss_mse + 0.3 * loss_ssim  # SSIM의 비중을 높임
        
        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * images.size(0)
    
    train_loss /= len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images in val_loader:
            images = images.to(device)
            outputs = model(images)
            loss_mse = criterion_mse(outputs, images)
            loss_ssim = 1 - criterion_ssim(outputs, images)
            loss = 0.7 * loss_mse + 0.3 * loss_ssim
            val_loss += loss.item() * images.size(0)
    
    val_loss /= len(val_loader.dataset)
    print(f"Validation Loss: {val_loss:.4f}")
    
    # 학습률 조정
    scheduler.step(val_loss)

    # Early Stopping 체크
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'fin_vit_unet_autoencoder_model_optimized.pth')
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= early_stop_patience:
            print("Early stopping activated.")
            break

주요 구성 요소와 학습 과정 설명:

1. 데이터셋 준비 (IQADataset 클래스)

  • IQADataset 클래스는 이미지 데이터를 불러오고, 그 이미지를 전처리하는 역할
  • 이미지 파일 경로를 읽고, 이미지를 RGB 형식으로 변환한 후, 전처리(크기 조정, 데이터 증강, 정규화) 수행

2. 이미지 전처리

  • 이미지를 224x224로 크기 조정하고, 정규화 과정을 거쳐 Vision Transformer(ViT)에서 사용할 수 있도록 변환
  • 이 전처리는 데이터 증강(RandomResizedCrop, RandomHorizontalFlip)을 포함하여 더 다양한 패턴 학습

3. 모델 정의 (ViTUAutoencoder 클래스)

  • ViT (Vision Transformer) : 이미지를 처리하고, 중요한 특징을 추출하는 역할
  • ViT의 출력으로 인코딩된 이미지 특징을 추출하며, 이는 U-Net 디코더에 의해 복원
    • ViT → 196x768 크기의 특징 맵 생성. 여기서 <196은 이미지 패치의 개수, 768은 각 패치의 차원>
    • U-Net 디코더 : 여러 단계의 업샘플링을 통해 원본 이미지와 동일한 해상도의 복원 이미지 생성

4. 손실 함수

  • MSE (Mean Squared Error)SSIM (Structural Similarity Index Measure)를 사용해 이미지 복원 성능을 평가
    • MSE 손실 : 원본 이미지와 복원된 이미지 사이의 픽셀 단위 차이 계산
    • SSIM 손실 : 두 이미지의 구조적 유사성 평가. → 이미지의 전반적인 구조적 정보를 보존하는 데 도움을 줌

핵심 기능

  • 이 모델은 주어진 손상된 이미지를 입력으로 받아 이미지를 복원하는 역할 수행
  • 모델은 픽셀 단위의 손실(MSE)이미지의 구조적 손실(SSIM)을 함께 고려하여 학습
  • 최종적으로는 입력 이미지의 중요한 특징을 학습하고 복원할 수 있도록 설계된 Autoencoder 모델

결과



github : https://github.com/degull/IQA_self_supervised.git

profile
그래도 해야지

0개의 댓글