When you build an image classifier by combining a CNN with Attention, it is easy to be unsure whether the output channel dimension of the Attention layer may simply be set to the number of classes. The short answer is that the Attention layer should keep the feature embedding dimension, and the projection to num_classes belongs in the final fully connected layer. This post walks through why that is, the standard way to structure such a model, and a ViT (Vision Transformer)-style architecture example.
Input image
↓
CNN backbone (e.g., ResNet, Conv blocks)
↓
Flatten or Global Average Pooling
↓
Attention layer (e.g., Self-Attention, Multi-Head Attention)
↓
MLP or Fully Connected Layer → Output dimension: num_classes
↓
Softmax (optional, for inference)
import torch
import torch.nn as nn


class CNNWithAttentionClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.flatten = nn.AdaptiveAvgPool2d((1, 1))  # (B, C, 1, 1)
        self.attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.cnn(x)                                  # (B, 64, H, W)
        x = self.flatten(x).squeeze(-1).transpose(1, 2)  # (B, 1, 64)
        attn_out, _ = self.attn(x, x, x)                 # (B, 1, 64)
        out = self.fc(attn_out.squeeze(1))               # (B, num_classes)
        return out
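A minimal sanity check of the shapes might look like the sketch below; the 32×32 RGB input and 10 classes are assumed here purely for illustration.

model = CNNWithAttentionClassifier(num_classes=10)
dummy = torch.randn(4, 3, 32, 32)   # (B, C, H, W)
logits = model(dummy)               # raw class scores
probs = logits.softmax(dim=-1)      # optional softmax for inference, as in the diagram above
print(logits.shape)                 # torch.Size([4, 10])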
A ViT splits the image into patches and treats each patch as a token that is fed into a Transformer. With the defaults used below (a 224×224 RGB image and 16×16 patches), that gives (224/16)² = 196 tokens, each flattened to a 3·16·16 = 768-dimensional vector before the linear patch embedding.
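As a quick illustration of that patching step, here is a small shape-check sketch (the dummy input and variable names are only for illustration) using unfold, the same mechanism the forward pass below relies on:

import torch

x = torch.randn(2, 3, 224, 224)                      # (B, C, H, W), dummy input
p = 16                                               # assumed patch size
patches = x.unfold(2, p, p).unfold(3, p, p)          # (B, C, 14, 14, 16, 16)
patches = patches.contiguous().view(2, 3, -1, p, p)  # (B, C, 196, 16, 16)
tokens = patches.permute(0, 2, 1, 3, 4).flatten(2)   # (B, 196, 768)
print(tokens.shape)                                  # torch.Size([2, 196, 768])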
class ViTClassifier(nn.Module):
    def __init__(self, image_size=224, patch_size=16, in_channels=3, num_classes=10,
                 hidden_dim=768, num_heads=12, num_layers=6):
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size
        self.patch_size = patch_size
        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size  # use the configured patch size rather than re-deriving it from the input
        x = x.unfold(2, p, p).unfold(3, p, p)     # (B, C, H/p, W/p, p, p)
        x = x.contiguous().view(B, C, -1, p, p)   # (B, C, num_patches, p, p)
        x = x.permute(0, 2, 1, 3, 4).flatten(2)   # (B, num_patches, patch_dim)
        x = self.patch_embed(x)                   # (B, num_patches, hidden_dim)
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)      # prepend the CLS token
        x = x + self.pos_embed
        x = self.transformer(x)
        cls_output = x[:, 0]                      # CLS token output
        out = self.fc(cls_output)
        return out
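A similar sanity-check sketch for the ViT-style model (the default 224×224 input and 10 classes are assumed; the batch size of 2 is arbitrary):

vit = ViTClassifier(image_size=224, patch_size=16, num_classes=10)
dummy = torch.randn(2, 3, 224, 224)  # (B, C, H, W)
logits = vit(dummy)
print(logits.shape)                  # torch.Size([2, 10])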