As of now, the SOTA for vision tasks is still the ResNet-like model, a classic convolution-based architecture. In NLP, the Transformer not only delivers strong performance but, thanks to the pre-train → fine-tune pipeline, also makes it easy to build application-specific models from a small amount of data. Inspired by this, the paper applies the Transformer to vision tasks.
When trained on a mid-sized dataset such as ImageNet, ViT sadly performs worse than ResNet. However, the Transformer lacks the inductive biases that make CNNs work so well on vision tasks (translation invariance + locality), so it is perhaps only natural that it struggles to reach higher accuracy from the same amount of data. Accordingly, it is only when trained on much larger datasets (14M-300M images) that ViT finally beats the SOTA.
Model Architecture overview
To feed a 2D image into the Transformer, it is first partitioned into patches: the original image of resolution (H, W) is simply cut into N patches of size (P, P), where N = HW/P^2. Each patch is flattened and passed through a trainable linear projection layer. Once these 1D embeddings are obtained, a class token is concatenated in front and positional embeddings are added. The class token is assumed to aggregate context over the entire image; it is later fed into the MLP head and used as the target of classification.
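To make the shape arithmetic concrete, here is a minimal sketch using ViT-Base/16-style numbers (the same defaults the code below uses; the specific values are only an illustration):

H = W = 384                 # input resolution
P = 16                      # patch size
D = 768                     # embedding dimension after the linear projection
N = (H * W) // (P * P)      # number of patches
print(N)                    # 576
print((N + 1, D))           # sequence shape after prepending the class token: (577, 768)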
In the Transformer encoder, multi-head attention, an MLP, layer normalization, and residual connections are applied in every block, and N (= 12) such blocks are stacked, as sketched below.
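Written out in the notation of the ViT paper (z_l is the token sequence after block l, MSA is multi-head self-attention, LN is LayerNorm), each pre-norm block computes

z'_l = MSA(LN(z_{l-1})) + z_{l-1}
z_l  = MLP(LN(z'_l)) + z'_l

which is exactly the structure implemented in Block.forward below.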
As mentioned above, the Transformer layers in ViT operate globally, so compared with CNN-based models the architecture carries little inductive bias (translation invariance, locality); only the MLP layers are local and translationally equivariant.
For the experimental details, refer to the paper.
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
    """Split the image into patches and then embed them.

    Parameters
    ----------
    img_size : int
        Size of the (square) image.
    patch_size : int
        Size of each (square) patch.
    in_chans : int
        Number of input channels.
    embed_dim : int
        Embedding dimension.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image.
    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches
        and their embedding, by setting kernel_size == stride.
    """
    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super(PatchEmbed, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`.
        """
        x = self.proj(x)       # (n_samples, embed_dim, img_size // patch_size, img_size // patch_size)
        x = x.flatten(2)       # (n_samples, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (n_samples, n_patches, embed_dim)
        return x
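A minimal sanity check, assuming the class above is defined as written (the dummy batch is only for illustration): a 384x384 RGB input with 16x16 patches should yield 576 patch tokens of dimension 768.

pe = PatchEmbed(img_size=384, patch_size=16, in_chans=3, embed_dim=768)
dummy = torch.randn(2, 3, 384, 384)   # (n_samples, in_chans, img_size, img_size)
print(pe(dummy).shape)                # torch.Size([2, 576, 768])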
class Attention(nn.Module):
    """Attention mechanism.

    Parameters
    ----------
    dim : int
        The input and output dimension of per-token features.
    n_heads : int
        Number of attention heads.
    qkv_bias : bool
        If True, include bias in the qkv projections.
    attn_p : float
        Dropout probability applied to the attention weights.
    proj_p : float
        Dropout probability applied to the output tensor.

    Attributes
    ----------
    scale : float
        Normalizing constant for the dot product.
    qkv : nn.Linear
        Linear projection for the query, key, and value.
    proj : nn.Linear
        Linear projection that takes in the concatenated output of all
        attention heads and maps it into a new space.
    attn_drop, proj_drop : nn.Dropout
        Dropout layers.
    """
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super(Attention, self).__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)` -- same dimensionality.
        """
        n_samples, n_tokens, dim = x.shape
        if dim != self.dim:
            raise ValueError

        qkv = self.qkv(x)  # (n_samples, n_patches + 1, dim * 3)
        qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, n_samples, n_heads, n_patches + 1, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        k_t = k.transpose(-2, -1)    # (n_samples, n_heads, head_dim, n_patches + 1)
        dp = (q @ k_t) * self.scale  # (n_samples, n_heads, n_patches + 1, n_patches + 1)
        attn = dp.softmax(dim=-1)
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v                      # (n_samples, n_heads, n_patches + 1, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2)  # (n_samples, n_patches + 1, n_heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)       # (n_samples, n_patches + 1, dim)

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)
        return x
class MLP(nn.Module):
    """Multilayer perceptron.

    Parameters
    ----------
    in_features : int
        Number of input features.
    hidden_features : int
        Number of nodes in the hidden layer.
    out_features : int
        Number of output features.
    p : float
        Dropout probability.

    Attributes
    ----------
    fc : nn.Linear
        First linear layer.
    act : nn.GELU
        GELU activation.
    fc2 : nn.Linear
        Second linear layer.
    drop : nn.Dropout
        Dropout layer.
    """
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, in_features)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, out_features)`.
        """
        x = self.fc(x)   # (n_samples, n_patches + 1, hidden_features)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)  # (n_samples, n_patches + 1, out_features)
        x = self.drop(x)
        return x
class Block(nn.Module):
    """Transformer block.

    Parameters
    ----------
    dim : int
        Embedding dimension.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension size of the MLP module with respect to dim.
    qkv_bias : bool
        If True, include bias in the qkv projections.
    p, attn_p : float
        Dropout probabilities.

    Attributes
    ----------
    norm1, norm2 : nn.LayerNorm
        Layer normalization.
    attn : Attention
        Attention module.
    mlp : MLP
        MLP module.
    """
    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p,
        )
        hidden_features = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim,
            p=p,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        """
        # Pre-norm residual connections around attention and MLP.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class VisionTransformer(nn.Module):
    """Simplified implementation of ViT.

    Parameters
    ----------
    img_size : int
        Size of the (square) input image.
    patch_size : int
        Size of each (square) patch.
    in_chans : int
        Number of input channels.
    n_classes : int
        Number of classes.
    embed_dim : int
        Dimensionality of the token/patch embeddings.
    depth : int
        Number of blocks.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension of the MLP module.
    qkv_bias : bool
        If True, include bias in the qkv projections.
    p, attn_p : float
        Dropout probabilities.

    Attributes
    ----------
    patch_embed : PatchEmbed
        Instance of the PatchEmbed layer.
    cls_token : nn.Parameter
        Learnable parameter that will represent the first token in the sequence.
    pos_emb : nn.Parameter
        Positional embedding of the cls token + all the patches.
        It has `(n_patches + 1) * embed_dim` elements.
    pos_drop : nn.Dropout
        Dropout layer applied after adding the positional embedding.
    blocks : nn.ModuleList
        List of Block modules.
    norm : nn.LayerNorm
        Layer normalization applied before the classification head.
    """
    def __init__(self, img_size=384, patch_size=16, in_chans=3, n_classes=1000,
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True,
                 p=0., attn_p=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_emb = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(p)

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    n_heads=n_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    p=p,
                    attn_p=attn_p,
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)  # final classification head

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        logits : torch.Tensor
            Logits over all classes - `(n_samples, n_classes)`.
        """
        n_samples = x.shape[0]
        x = self.patch_embed(x)  # (n_samples, n_patches, embed_dim)

        cls_token = self.cls_token.expand(n_samples, -1, -1)  # (n_samples, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1)                  # (n_samples, n_patches + 1, embed_dim)
        x = x + self.pos_emb
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        cls_token_final = x[:, 0]  # take only the class token
        x = self.head(cls_token_final)
        return x
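Finally, a minimal usage sketch (assuming the classes above; the random batch is only a smoke test): with the default ViT-Base/16-style hyperparameters, the output should be an (n_samples, n_classes) logit tensor.

model = VisionTransformer()            # defaults: img_size=384, patch_size=16, embed_dim=768, depth=12
dummy = torch.randn(2, 3, 384, 384)
logits = model(dummy)
print(logits.shape)                    # torch.Size([2, 1000])
print(sum(p.numel() for p in model.parameters()))  # roughly 86.9M parameters with these defaults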