An Image Is Worth 16x16 Words : Transformers for Image Recognition at Scale

Sonny020402·2023년 7월 20일

Deep Learning

현 시점까지 Vision task에서 SOTA는 convolution 기반의 클래식한 아키텍처인 ResNet-like model이다. NLP task에서는 Transformer가 도입되어 성능이 좋을 뿐만 아니라, Pre-train → Fine-tuning 학습 파이프라인 덕분에 적은 양의 data로도 application-specific 모델을 쉽게 만들 수 있다. 여기서 영감을 받아 vision task에 transformer를 사용해 보고자 한다. ^~^

Mid-sized dataset인 ImageNet으로 학습하였을 때는 ResNet에 비해 좋지 못한 성능을 보여 슬펐다 ㅜ~ㅜ. 그러나 CNN이 vision task에서 잘 먹히는 이유인 inductive bias(= translation equivariance + locality)를 transformer는 가지고 있지 않다는 점을 생각하면, 같은 양의 dataset으로 더 높은 성능을 내는 것은 당연히 어려운 일일지도 모른다. 실제로 larger-sized dataset(14M-300M images)으로 학습했을 때 비로소 ViT는 기존 SOTA를 이긴다.


Model Architecture overview

2D image를 transformer input으로 사용하기 위해 Patch partitioning을 진행한다. resolution이 (H, W)인 original image를 (P, P) 크기의 patch N개(where $N = HW/P^2$)로 단순히 잘라주면 된다. 이후 각 patch를 Flatten하여 trainable linear projection layer를 통과시킨다. 이렇게 patch별 1차원 embedding이 만들어지면 맨 앞에 cls token을 concat하고 positional embedding을 더해 준다. class token은 이미지 전체에 대한 context를 포함하고 있다고 가정되는 token으로서, encoder를 통과한 뒤 MLP head에 입력되어 classification에 사용된다.

Transformer encoder 파트에서는 Multi-head attention, MLP, LayerNorm, residual connection이 매 블록 적용되고, 이 블록을 L(=12)개 stack한다. (아래 수식의 L과 같은 기호이며, patch 개수 N과 구분하기 위해 L로 표기한다.)

$$
\begin{aligned}
z_0 &= [x_{class};\ x^1_p E;\ x^2_p E;\ \cdots;\ x^N_p E] + E_{pos} \\
z^\prime_l &= MSA(LN(z_{l-1})) + z_{l-1}, \qquad l = 1, \ldots, L \\
z_l &= MLP(LN(z^\prime_l)) + z^\prime_l, \qquad l = 1, \ldots, L \\
y &= LN(z^0_L)
\end{aligned}
$$

Inductive bias

전술하였듯, ViT의 self-attention layer는 global하게 동작하기 때문에 CNN based model에 비해 inductive bias(translation equivariance, locality)가 부족하다. ViT에서는 MLP layer만이 locality와 translation equivariance를 가진다.


import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into patches and embed them.

    Parameters
    ----------
    img_size : int
        Size of the (square) image.
    patch_size : int
        Size of the (square) patch.
    in_chans : int
        Number of input channels.
    embed_dim : int
        Embedding dimension.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image,
        computed as ``(img_size // patch_size) ** 2``.
    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches and
        their embedding, by setting ``kernel_size == stride``.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        # nn.Module must be initialised before submodules can be registered.
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # kernel_size == stride means the kernel never overlaps itself:
        # each output position sees exactly one patch.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`.
        """
        # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        x = self.proj(x)
        x = x.flatten(2)  # (n_samples, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (n_samples, n_patches, embed_dim)
        return x

class Attention(nn.Module):
    """Multi-head self-attention mechanism.

    Parameters
    ----------
    dim : int
        The input and output dimension of per-token features.
    n_heads : int
        Number of attention heads. Must divide ``dim`` evenly.
    qkv_bias : bool
        If True, include a bias in the query, key and value projection.
    attn_p : float
        Dropout probability applied to the attention weights
        (after the softmax).
    proj_p : float
        Dropout probability applied to the output tensor.

    Attributes
    ----------
    scale : float
        Normalizing constant for the dot product (``head_dim ** -0.5``).
    qkv : nn.Linear
        Linear projection for the query, key and value.
    proj : nn.Linear
        Linear projection that takes in the concatenated output of all
        attention heads and maps it into a new space.
    attn_drop, proj_drop : nn.Dropout
        Dropout layers.

    Raises
    ------
    ValueError
        If ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        # nn.Module must be initialised before submodules can be registered.
        super().__init__()
        if dim % n_heads != 0:
            # Otherwise the reshape in forward() fails with an obscure error.
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        # Scaling keeps the softmax inputs from growing with head_dim.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`, i.e. the same
            dimensionality as the input.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` differs from ``self.dim``.
        """
        n_samples, n_tokens, dim = x.shape
        if dim != self.dim:
            raise ValueError(
                f"expected last dimension {self.dim}, got {dim}"
            )

        qkv = self.qkv(x)  # (n_samples, n_patches + 1, dim * 3)
        qkv = qkv.reshape(
            n_samples, n_tokens, 3, self.n_heads, self.head_dim
        )  # (n_samples, n_patches + 1, 3, n_heads, head_dim)
        qkv = qkv.permute(
            2, 0, 3, 1, 4
        )  # (3, n_samples, n_heads, n_patches + 1, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        k_t = k.transpose(-2, -1)  # (n_samples, n_heads, head_dim, n_patches + 1)
        # (n_samples, n_heads, n_patches + 1, n_patches + 1)
        dp = (q @ k_t) * self.scale
        attn = dp.softmax(dim=-1)  # each row sums to 1 over the keys
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v  # (n_samples, n_heads, n_patches + 1, head_dim)
        weighted_avg = weighted_avg.transpose(
            1, 2
        )  # (n_samples, n_patches + 1, n_heads, head_dim)
        # Concatenate the heads: (n_samples, n_patches + 1, dim)
        weighted_avg = weighted_avg.flatten(2)

        x = self.proj(weighted_avg)  # (n_samples, n_patches + 1, dim)
        x = self.proj_drop(x)

        return x

class MLP(nn.Module):
    """Multilayer perceptron with one hidden layer.

    Parameters
    ----------
    in_features : int
        Number of input features.
    hidden_features : int
        Number of nodes in the hidden layer.
    out_features : int
        Number of output features.
    p : float
        Dropout probability.

    Attributes
    ----------
    fc : nn.Linear
        The first linear layer.
    act : nn.GELU
        GELU activation function.
    fc2 : nn.Linear
        The second linear layer.
    drop : nn.Dropout
        Dropout layer, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features, out_features, p=0.):
        # nn.Module must be initialised before submodules can be registered.
        super().__init__()
        self.fc = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, in_features)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, out_features)`.
        """
        x = self.fc(x)  # (n_samples, n_patches + 1, hidden_features)
        x = self.act(x)  # (n_samples, n_patches + 1, hidden_features)
        x = self.drop(x)  # (n_samples, n_patches + 1, hidden_features)
        x = self.fc2(x)  # (n_samples, n_patches + 1, out_features)
        x = self.drop(x)  # (n_samples, n_patches + 1, out_features)
        return x

class Block(nn.Module):
    """Transformer block (pre-norm attention and MLP, each with a residual).

    Parameters
    ----------
    dim : int
        Embedding dimension.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension size of the MLP module with
        respect to ``dim`` (``hidden_features = int(dim * mlp_ratio)``).
    qkv_bias : bool
        If True, include a bias in the query, key and value projection.
    p, attn_p : float
        Dropout probabilities. ``attn_p`` is forwarded to the attention
        weights dropout and ``p`` to the attention output dropout.

    Attributes
    ----------
    norm1, norm2 : nn.LayerNorm
        Layer normalization.
    attn : Attention
        Attention module.
    mlp : MLP
        MLP module.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        # nn.Module must be initialised before submodules can be registered.
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p,
        )
        hidden_features = int(dim * mlp_ratio)
        # NOTE(review): `p` is not forwarded to the MLP, so its dropout stays
        # at the MLP default (0.). Kept as in the original; confirm whether
        # the MLP should also receive `p`.
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        """
        # Residual connections around each normalised sub-layer.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x

class VisionTransformer(nn.Module):
    """Simplified implementation of the Vision Transformer (ViT).

    Parameters
    ----------
    img_size : int
        Both height and width of the (square) image.
    patch_size : int
        Both height and width of the (square) patch.
    in_chans : int
        Number of input channels.
    n_classes : int
        Number of classes.
    embed_dim : int
        Dimensionality of the token/patch embeddings.
    depth : int
        Number of blocks.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension of the MLP module.
    qkv_bias : bool
        If True, include a bias in the query, key and value projections.
    p, attn_p : float
        Dropout probabilities.

    Attributes
    ----------
    patch_embed : PatchEmbed
        Instance of the PatchEmbed layer.
    cls_token : nn.Parameter
        Learnable parameter that represents the first token in the
        sequence. It has ``embed_dim`` elements.
    pos_emb : nn.Parameter
        Positional embedding of the cls token + all the patches.
        It has ``(n_patches + 1) * embed_dim`` elements.
    pos_drop : nn.Dropout
        Dropout layer applied after adding the positional embedding.
    blocks : nn.ModuleList
        List of Block modules.
    norm : nn.LayerNorm
        Layer normalization applied after the last block.
    head : nn.Linear
        Final classification layer mapping ``embed_dim`` to ``n_classes``.
    """

    def __init__(
        self,
        img_size=384,
        patch_size=16,
        in_chans=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        p=0.,
        attn_p=0.,
    ):
        # nn.Module must be initialised before parameters and submodules
        # can be registered.
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One positional vector per patch plus one for the cls token.
        self.pos_emb = nn.Parameter(
            torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim)
        )
        self.pos_drop = nn.Dropout(p)
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    n_heads=n_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    p=p,
                    attn_p=attn_p,
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)  # final classifier

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        logits : torch.Tensor
            Logits over all classes - `(n_samples, n_classes)`.
        """
        n_samples = x.shape[0]
        x = self.patch_embed(x)  # (n_samples, n_patches, embed_dim)

        # Broadcast the single learnable cls token across the batch.
        cls_token = self.cls_token.expand(
            n_samples, -1, -1
        )  # (n_samples, 1, embed_dim)
        # Prepend it: (n_samples, 1 + n_patches, embed_dim)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_emb  # pos_emb broadcasts over the batch dimension
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        cls_token_final = x[:, 0]  # keep only the cls token
        x = self.head(cls_token_final)  # (n_samples, n_classes)

        return x

