[CLIP] CLIP CODE review 3

김보현·2024년 7월 23일
0

ComputerVision

목록 보기
7/11

Building features

특징 추출 및 유사도 계산

  1. 이미지 집합 확인

    batch = np.stack(images)
    print(f"dimension : {batch.ndim}, shape: {batch.shape}")
    • np.stack(images): images 리스트를 쌓아서 4차원 배열을 만든다.
    • print(f"dimension : {batch.ndim}, shape: {batch.shape}"): 배열의 차원 수와 형태를 출력한다.
  2. 텍스트 집합 확인

    sentence = ["This is " + desc for desc in texts]
    print(sentence[0])
    sen_to_tok = clip.tokenize(sentence)
    print(sen_to_tok[0])
    print(sen_to_tok.ndim) # 2
    print(sen_to_tok.shape) # (8, 77)
    • sentence 리스트: 각 설명 앞에 "This is "를 추가하여 만든다.
    • clip.tokenize(sentence): 텍스트를 토큰화한다.
    • print(sen_to_tok[0]): 첫 번째 문장의 토큰을 출력한다.
    • print(sen_to_tok.ndim)print(sen_to_tok.shape): 토큰 배열의 차원 수와 형태를 출력한다.
  3. 이미지와 텍스트 입력 설정

    image_input = torch.tensor(np.stack(images)).cuda()
    text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
    • 이미지를 텐서로 변환하고 GPU로 이동시킨다.
    • 텍스트를 토큰화하고 GPU로 이동시킨다.
  4. 모델의 전방 전달(Forward Pass)

    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
        text_features = model.encode_text(text_tokens).float()
    • torch.no_grad() 블록 안에서 그래디언트를 계산하지 않도록 한다.
    • model.encode_image(image_input)은 이미지 특징을 추출한다.
    • model.encode_text(text_tokens)은 텍스트 특징을 추출한다.
  5. 유사도 계산

    print(image_features.shape) # [8, 512]
    print(text_features.shape) # [8, 512]
    • 추출된 이미지와 텍스트 특징의 형태를 출력한다.

Calculating cosine similarity

  1. 특징 정규화 및 코사인 유사도 계산

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
    print(similarity)
    • image_featurestext_features를 정규화한다.
    • 코사인 유사도를 계산하여 similarity 행렬을 만든다.
    • 유사도 행렬을 출력한다.
  2. 유사도 행렬의 형태 확인

    print(similarity.shape) # 8 x 8
    • 유사도 행렬의 형태를 출력한다.
  3. 유사도 행렬 시각화

    count = len(descriptions)
    
    plt.figure(figsize=(20, 14))
    plt.imshow(similarity, vmin=0.1, vmax=0.3)
    plt.yticks(range(count), texts, fontsize=18)
    plt.xticks([])
    
    for i, image in enumerate(original_images):
        plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
    
    for x in range(similarity.shape[1]):
        for y in range(similarity.shape[0]):
            plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
    
    for side in ["left", "top", "right", "bottom"]:
        plt.gca().spines[side].set_visible(False)
    
    plt.xlim([-0.5, count - 0.5])
    plt.ylim([count + 0.5, -2])
    
    plt.title("Cosine similarity between text and image features", size=20)
    • 유사도 행렬을 시각화한다.
    • 유사도를 이미지와 텍스트 설명과 함께 표시한다.

Zero-Shot Image Classification

Zero-Shot 이미지 분류

  1. CIFAR-100 데이터셋 로드

    from torchvision.datasets import CIFAR100
    
    cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
    • CIFAR-100 데이터셋을 로드하고 전처리를 적용한다.
  2. 텍스트 설명 생성 및 토큰화

    text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
    text_tokens = clip.tokenize(text_descriptions).cuda()
    • CIFAR-100 클래스에 대한 텍스트 설명을 생성하고 토큰화한다.
  3. 텍스트 특징 추출 및 정규화

    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
    • 텍스트 특징을 추출하고 정규화한다.
  4. 코사인 유사도를 기반으로 분류

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
    • 이미지 특징과 텍스트 특징 사이의 유사도를 계산하여 확률로 변환한다.
    • 가장 높은 5개의 확률과 그에 해당하는 라벨을 추출한다.
  5. 결과 시각화

    plt.figure(figsize=(16, 16))
    
    for i, image in enumerate(original_images):
        plt.subplot(4, 4, 2 * i + 1)
        plt.imshow(image)
        plt.axis("off")
    
        plt.subplot(4, 4, 2 * i + 2)
        y = np.arange(top_probs.shape[-1])
        plt.grid()
        plt.barh(y, top_probs[i])
        plt.gca().invert_yaxis()
        plt.gca().set_axisbelow(True)
        plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
        plt.xlabel("probability")
    
    plt.subplots_adjust(wspace=0.5)
    plt.show()
    • 원본 이미지를 표시하고, 각 이미지에 대한 예측 확률을 막대 그래프로 시각화한다.
    • plt.subplots_adjust(wspace=0.5): 서브플롯 간의 간격을 조정한다.
profile
Fall in love with Computer Vision

0개의 댓글