[CLIP] CLIP CODE review 2

김보현·2024년 7월 23일
0

ComputerVision

목록 보기
6/11

Loading the model

CLIP 모델 나열 & 로드

import clip

clip 라이브러리를 임포트한다.

CLIP은 OpenAI에서 개발했다.
이미지와 텍스트를 함께 이해하는 모델이다.

사용 가능한 CLIP 모델 나열

clip.available_models()

clip.available_models()는 사용 가능한 CLIP 모델의 이름을 리스트로 반환한다.

반환된 모델 목록:

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']

각 항목은 CLIP 모델의 이름을 나타낸다.
RN50는 ResNet-50 기반 모델을 의미한다.
ViT-B/32는 Vision Transformer 기반 모델을 의미한다.

CLIP 모델 로드

model, preprocess = clip.load("ViT-B/32")

clip.load("ViT-B/32")ViT-B/32 모델을 로드하고 모델과 전처리 함수를 반환한다.
model 변수에는 모델 객체가 저장된다.
preprocess 변수에는 입력 데이터를 모델에 맞게 전처리하는 함수가 저장된다.

모델을 GPU로 이동하고 평가 모드로 설정

model.cuda().eval()

model.cuda()는 모델을 GPU로 이동시킨다.
model.eval()은 모델을 평가 모드로 설정한다. 평가모드란 모델을 학습이 아닌 추론 모드로 사용하겠다는 의미다.

모델 주요 속성 확인

input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

model.visual.input_resolution: 모델의 시각적 입력 해상도를 반환한다.
model.context_length: 텍스트 입력의 최대 길이를 반환한다.
model.vocab_size는 모델의 어휘 크기를 반환한다.

모델 파라미터 수 & 주요 속성 출력

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")

model.parameters(): 모델의 모든 파라미터를 반환한다.
각 파라미터의 shape을 곱해서 전체 파라미터 수를 계산하고, 이를 모두 더한 값을 출력한다.
np.prod(p.shape): 파라미터의 shape의 모든 값을 곱하여 파라미터의 수를 계산한다.
np.sum(...): 모든 파라미터 수를 더한다.
f"{...:,}"는 숫자를 쉼표로 구분된 형식으로 출력한다.

print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

모델의 입력 해상도, 텍스트 문맥 길이, 어휘 크기를 출력한다.

예시 출력

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408

이 출력은 로드한 ViT-B/32 모델의 파라미터 수가 151,277,313개, 입력 해상도가 224x224, 문맥 길이가 77, 어휘 크기가 49,408임을 나타낸다.

Image Preprocessing

이미지 전처리

  1. CLIP 라이브러리와 이미지 전처리 객체 로드

    preprocess
    • preprocess 객체는 clip.load() 함수의 두 번째 반환 값으로, 이미지를 모델이 기대하는 형태로 전처리하기 위해 사용하는 torchvision.transforms 객체이다.
  2. 전처리 단계 나열

    Compose(
        Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
        CenterCrop(size=(224, 224))
        <function _convert_image_to_rgb at 0x79e73f943250>
        ToTensor()
        Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
    )
    • Resize(size=224, interpolation=bicubic, max_size=None, antialias=True): 이미지를 224x224 크기로 리사이즈한다. bicubic 보간법을 사용합니다.
    • CenterCrop(size=(224, 224)): 이미지의 중심을 기준으로 224x224 크기로 자른다.
    • <function _convert_image_to_rgb at 0x79e73f943250>: 이미지를 RGB 형식으로 변환한다.
    • ToTensor(): 이미지를 텐서 형식으로 변환한다.
    • Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)): 이미지 픽셀 값을 주어진 평균 및 표준편차로 정규화한다.

Text Preprocessing

텍스트 전처리

  1. 텍스트 토크나이징

    clip.tokenize("Hello World!")
    • clip.tokenize()는 문자열을 토큰으로 변환하는 함수이다. CLIP 모델이 기대하는 입력 형식으로 변환한다.
  2. 토큰화된 출력 확인

    tensor([[49406,  3306,  1002,   256, 49407,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], dtype=torch.int32)
    • "Hello World!"라는 문장이 토큰으로 변환되고, 패딩이 추가되어 77개의 토큰 길이를 맞춘다.

Setting up input images and texts

입력 이미지 및 텍스트 설정

  1. 필요한 라이브러리 임포트

    import os
    import skimage
    import IPython.display
    import matplotlib.pyplot as plt
    from PIL import Image
    import numpy as np
    
    from collections import OrderedDict
    import torch
    
    %matplotlib inline
    %config InlineBackend.figure_format = 'retina'
    • 필요한 라이브러리를 임포트한다.
      skimagePIL: 이미지 처리
      matplotlib: 이미지 시각화
      torch: 텐서 연산
  2. 이미지와 텍스트 설명 설정

    descriptions = {
        "page": "a page of text about segmentation",
        "chelsea": "a facial photo of a tabby cat",
        "astronaut": "a portrait of an astronaut with the American flag",
        "rocket": "a rocket standing on a launchpad",
        "motorcycle_right": "a red motorcycle standing in a garage",
        "camera": "a person looking at a camera on a tripod",
        "horse": "a black-and-white silhouette of a horse",
        "coffee": "a cup of coffee on a saucer"
    }
    • descriptions 딕셔너리는 이미지 파일 이름과 그에 대응하는 텍스트 설명을 저장한다.

데이터 로드 및 시각화

  1. skimage 데이터 확인 및 필터링

    original_images = [] # skimage의 원본 image
    images = [] # preprocessed image가 들어감
    texts = [] # description(prompt)가 들어감, 만약 image의 이름이 coffee면 "a cup of coffee on a saucer"
    plt.figure(figsize=(16, 5))
    
    for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
        name = os.path.splitext(filename)[0]
        if name not in descriptions:
            continue
    
        image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
    
        plt.subplot(2, 4, len(images) + 1)
        plt.imshow(image)
        plt.title(f"{filename}\n{descriptions[name]}")
        plt.xticks([])
        plt.yticks([])
    
        original_images.append(image)
        images.append(preprocess(image))
        texts.append(descriptions[name])
    
    plt.tight_layout()
    • original_images, images, texts 리스트를 초기화한다.
    • plt.figure(figsize=(16, 5))는 16x5 크기의 시각화 영역을 설정한다.
    • for 루프는 skimage.data_dir 디렉터리에서 .png 또는 .jpg 확장자를 가진 파일들을 반복한다.
      • 파일 이름에서 확장자를 제거한 이름을 name 변수에 저장한다.
      • namedescriptions 딕셔너리에 없는 경우 continue로 다음 파일로 넘어간다.
      • 이미지를 RGB 형식으로 열고 image 변수에 저장한다.
      • plt.subplot(2, 4, len(images) + 1)는 2행 4열의 서브플롯 중 현재 이미지 인덱스에 맞는 위치를 선택한다.
      • 이미지를 표시하고 제목으로 파일 이름과 설명을 추가한다.
      • xticksyticks를 빈 리스트로 설정하여 축 눈금을 제거한다.
      • 원본 이미지를 original_images 리스트에 추가한다.
        전처리된 이미지는 images 리스트에 추가한다.
      • 설명을 texts 리스트에 추가한다.
    • plt.tight_layout(): 플롯의 레이아웃을 자동으로 조정하여 겹치지 않게 만든다.
profile
Fall in love with Computer Vision

0개의 댓글