Swinv2-tiny 모델을 통한 이미지 Embedding

곤돌이·2024년 8월 4일

Image Embedding Search

목록 보기
4/5

지난번에 EfficientNetV2B1 모델을 통해 FC Layer 이전의 Pooling 에서의 1280 차원의 텐서를 Embedding 값으로 하여 VectorDB 에 저장하고, 새로운 이미지를 통해 유사한 이미지를 추출하는 실험을 해보았습니다.

다만, 기대했던 결과가 나오지 않아 이번엔 좀 더 이미지 추출을 잘하는 microsoft/swinv2-tiny-patch4-window8-256 모델을 기반으로 동일하게 FC 이전의 768 차원의 텐서를 Embedding 값으로 VectorDB에 저장하고 같은 테스트를 해보도록 하겠습니다. (tiny모델이라 성능이 안높을 수도 있을까 🥲)


모델 빌드하기

Huggingface에서 모델을 다운받고, Pooling Layer에서 출력값을 내도록 모델을 설정 했습니다.

import os
import numpy as np
import tensorflow as tf
import psycopg2
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

BATCH_SIZE = 32
INPUT_SHAPE = (448, 448)

processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
base_model = AutoModelForImageClassification.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
model = base_model.swinv2


class ImageNetDataset(Dataset):
    def __init__(self, path):
        self.files = glob(os.path.join(path, '*.jpeg'))
        self.processor = processor
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        image = read_image(self.files[idx], torchvision.io.image.ImageReadMode.RGB)
        path = os.path.basename(self.files[idx])
        
        processed_image = self.processor(image, return_tensors='pt')['pixel_values'].squeeze(0)
        
        return processed_image, path


ds = ImageNetDataset('D:/data/imagenet1k/Data/CLS-LOC/val')
data_loader = DataLoader(ds, batch_size=32, shuffle=False)

model.eval()
model.to('cuda')

Embedding 값을 출력하고 VectorDB에 저장하기

지난번과 마찬가지로 모델의 Forward된 결과(Embedding 값)을 pgvector 에 저장을 했습니다.
개인적으로 간단한 코드의 경우 Pytorch 보단 Tensorflow가 성능이 좋은 것 같습니다. 😀

conn = psycopg2.connect('postgresql://namkon:${SECRET}@${IP_ADDR}:5432/test')
cursor = conn.cursor()
sql = "insert into swinv2_tiny (filename, embedding) values ('{filename}', '{embedding}')"

with torch.no_grad():
    for images, paths in data_loader:
        embeddings = model(images.to('cuda'))
        for embedding, path in zip(embeddings.pooler_output, paths):
            cursor.execute(sql.format(filename=path, embedding=embedding.tolist()))
    
conn.commit()

결과 테스트

지난번과 마찬가지로 위 Embedding 에 사용하지 않은 이미지를 기반으로 가장 근접한 이미지들을 뽑아보고 이미지 확인을 해보았습니다.

test_path = 'D:/data/imagenet1k/Data/CLS-LOC/test/ILSVRC2012_test_00001580.jpeg'
test_image = read_image(test_path, torchvision.io.image.ImageReadMode.RGB)
test_image = processor(test_image, return_tensors='pt')['pixel_values'].squeeze(0)
test_image = test_image.unsqueeze(0)

embedding = model(test_image.to('cuda'))

%%time
sql = f"select * from swinv2_tiny order by embedding <-> '{embedding.pooler_output.tolist()[0]}' limit 3"
cursor.execute(sql)
ret = cursor.fetchall()
# CPU times: total: 0 ns
# Wall time: 141 ms

base_path = 'D:/data/imagenet1k/Data/CLS-LOC/val'

plt.subplots(figsize=(20,12))

plt.subplot(1,4,1)
plt.imshow(test_image[0].permute(1, 2, 0))
plt.axis('off')
plt.title('Test Image')

plt.subplot(1,4,2)
plt.imshow(Image.open(os.path.join(base_path, ret[0][1])))
plt.axis('off')
plt.title('1st close image')

plt.subplot(1,4,3)
plt.imshow(Image.open(os.path.join(base_path, ret[1][1])))
plt.axis('off')
plt.title('2nd closest image')

plt.subplot(1,4,4)
plt.imshow(Image.open(os.path.join(base_path, ret[2][1])))
plt.axis('off')
plt.title('3rd closest image')

plt.show()


Conclusion

microsoft/swinv2-tiny-patch4-window8-256 을 통해 Embedding 실험을 해보았는데, 지난번 했던 EfficientNetV2B1 모델 대비 Embeddign 차원 감소(1280 -> 768)과 검색속도 2배 상승 (297ms -> 141ms) 그리고 가장 중요한 성능또한 제대로 나옴을 확인했습니다.
실제 업무에서도 잘 사용할 수 있을 것 같군요!!

Swin Transformer 사랑합니다 😘
다음번엔 CLIP 모델을 이용하여 텍스트 검색을 하여 원하는 이미지를 찾는 것을 해보도록 하겠습니다.

profile
Data Scientist

0개의 댓글