딥러닝 실전 팁: Attention 출력과 클래스 수의 관계 (ViT 예시 포함)

Bean·2025년 5월 25일
0

인공지능

목록 보기
40/123

[딥러닝 정리] CNN과 Attention으로 Image Classification 할 때 Output은 어떻게 설정할까? (+ ViT 구조 예시)

CNN과 Attention을 조합해 이미지 분류기를 만들 때, Attention Layer의 출력 채널 수를 클래스 개수로 설정해도 되는지 헷갈릴 수 있습니다. 이 글에서는 그 이유정석적인 해결 방법, 그리고 ViT(Visual Transformer) 스타일 구조 예시까지 정리해드립니다.


1. Attention Layer의 Output을 클래스 수로 직접 설정하지 않는 이유

  • Attention Layer는 이미지의 관계 정보에 집중해서 **더 나은 표현(feature representation)**을 만드는 역할을 합니다.
  • 이 출력은 보통 **중간 차원의 feature vector (예: hidden dimension 256, 512 등)**이며,
    최종 분류는 Linear (Fully Connected) Layer가 담당합니다.
  • 따라서 Attention Layer의 output을 클래스 개수만큼 설정하지 않고,
    Linear Layer에서 output dimension을 num_classes로 줄이는 방식이 일반적입니다.

2. 기본 구조: CNN + Attention + Classifier

Input image  
   ↓  
CNN backbone (예: ResNet, Conv blocks)  
   ↓  
Flatten or Global Average Pooling  
   ↓  
Attention layer (예: Self-Attention, Multi-Head Attention)  
   ↓  
MLP or Fully Connected Layer → Output dimension: num_classes  
   ↓  
Softmax (optional, for inference)
  • CNN → 로우 레벨 특징 추출
  • Attention → 관계 기반 정보 강화
  • FC Layer → 최종 클래스 분류

3. PyTorch 코드 예시: CNN + Attention

import torch
import torch.nn as nn

class CNNWithAttentionClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.flatten = nn.AdaptiveAvgPool2d((1, 1))  # (B, C, 1, 1)
        self.attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.cnn(x)  # (B, 64, H, W)
        x = self.flatten(x).squeeze(-1).transpose(1, 2)  # (B, 1, 64)
        attn_out, _ = self.attn(x, x, x)  # (B, 1, 64)
        out = self.fc(attn_out.squeeze(1))  # (B, num_classes)
        return out

4. ViT 구조 기반 구현 예시 (Visual Transformer)

ViT는 이미지를 patch 단위로 쪼개고, 이를 토큰처럼 다루어 Transformer에 입력합니다.

class ViTClassifier(nn.Module):
    def __init__(self, image_size=224, patch_size=16, in_channels=3, num_classes=10, hidden_dim=768, num_heads=12, num_layers=6):
        super().__init__()
        assert image_size % patch_size == 0, "Image must be divisible by patch size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size

        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        B, C, H, W = x.shape
        patch_size = int((x[0].numel() // C) ** 0.5)
        x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        x = x.contiguous().view(B, C, -1, patch_size, patch_size)
        x = x.permute(0, 2, 1, 3, 4).flatten(3)  # (B, num_patches, patch_dim)
        x = self.patch_embed(x)  # (B, num_patches, hidden_dim)

        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embed

        x = self.transformer(x)
        cls_output = x[:, 0]  # CLS token output
        out = self.fc(cls_output)
        return out

5. 요약

  • Attention Layer는 feature extractor입니다. 직접 클래스 수로 출력하지 않고, Linear Layer를 통해 분류합니다.
  • CNN + Attention 조합은 간단하고 강력한 구조입니다.
  • ViT 스타일은 patch embedding + transformer encoder로 구성되며, CLS token의 출력만 사용해 분류합니다.

profile
AI developer

0개의 댓글