An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale

Sonny020402 · July 20, 2023

Introduction

Up to this point, the SOTA for vision tasks has been held by classic convolution-based architectures, i.e. ResNet-like models. In NLP, the Transformer not only performs well but, with the pre-train → fine-tune pipeline, also makes it easy to build application-specific models from a small amount of data. Inspired by this, the authors apply a Transformer to vision tasks.

When trained on a mid-sized dataset such as ImageNet, ViT performs worse than ResNet. However, considering that the Transformer lacks the inductive biases that make CNNs work so well on vision tasks, namely translation equivariance and locality, it is perhaps unsurprising that it struggles to beat CNNs with the same amount of data. When pre-trained on much larger datasets of 14M-300M images (ImageNet-21k, JFT-300M), ViT finally surpasses the convolutional SOTA.

Method

Model Architecture overview

To feed a 2D image into a Transformer, the image is first partitioned into patches. The original image of resolution (H, W) is simply cut into N patches of size (P, P), where N = HW/P². Each patch is flattened and passed through a trainable linear projection layer. A learnable [class] token is then prepended to the resulting sequence of 1-D embeddings, and positional embeddings are added. The class token is assumed to aggregate the context of the whole image and is later fed into the MLP head, which performs the classification.
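
As a quick, illustrative walk-through of the shapes involved (the 224×224 resolution, P = 16, and D = 768 below are assumed example values for this sketch, not prescribed by the post):

H = W = 224               # assumed input resolution
P = 16                    # assumed patch size
D = 768                   # embedding dimension of the linear projection
N = (H * W) // (P * P)    # number of patches: 196
# each (P, P, 3) patch is flattened into a vector of P*P*3 = 768 values and
# projected to D dimensions; the [class] token is prepended and positional
# embeddings are added, giving a sequence of shape (N + 1, D) = (197, 768)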

In the Transformer encoder, multi-head self-attention, an MLP, LayerNorm, and residual connections are applied in every block, and L (= 12 for ViT-Base) such blocks are stacked.

$$
\begin{aligned}
z_0 &= [x_{class};\, x^1_p E;\, x^2_p E;\, \dots;\, x^N_p E] + E_{pos} \\
z'_\ell &= \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, \qquad \ell = 1, \dots, L \\
z_\ell &= \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell, \qquad \ell = 1, \dots, L \\
y &= \mathrm{LN}(z^0_L)
\end{aligned}
$$

Inductive bias

As noted above, ViT's self-attention layers are global, so the model has much less image-specific inductive bias (translation equivariance, locality) than CNN-based models; only the MLP layers are local and translation equivariant.

Experiments

See the paper for the detailed experimental results.

Implementation

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """ Split image intpo patches and then embed them

    Parameters
    ----------
    img_size : int
        size of the image

    patch_size : int
        size of the patch

    in_chans : int 
        # of input channels

    embed_dim : int
        embedding dimension
    
    Attributes
    ----------
    n_patches : int
        # of patches inside our image

    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches and their embedding, by setting kernel_size == stride
    """
    def __init__(self, img_size, patch_size, in_chans = 3, embed_dim = 768):
        super(PatchEmbed,self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size = patch_size,
            stride = patch_size
        )

    def forward(self, x):
        """ Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`
        
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`
        """
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1,2)
        return x
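
# Quick shape check (illustrative values, not from the original post): a 384x384
# image split into 16x16 patches should give (384 // 16) ** 2 = 576 patch tokens,
# each embedded into 768 dimensions.
_pe = PatchEmbed(img_size=384, patch_size=16, in_chans=3, embed_dim=768)
assert _pe(torch.randn(2, 3, 384, 384)).shape == (2, 576, 768)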

class Attention(nn.Module):
    """ Attention mechanism

    Parameters
    ----------
    dim : int
        The input and output dimension of per-token features
    n_heads : int
        # of attention heads
    qkv_bias : bool
        If True, include a bias term in the query, key, and value projections
    attn_p : float
        Dropout probability applied to the attention weights
    proj_p : float 
        Dropout probability applied to output tensor

    Attributes
    ----------
    scale : float
        Normalizing constant for the dot product
    qkv : nn.Linear
        Linear projection for the query, key, and value
    proj : nn.Linear
        Linear projection that takes in the concatenated output of all attention heads and maps it into a new space
    attn_drop, proj_drop : nn.Dropout
        Dropout layers
    """
    def __init__(self, dim, n_heads=12, qkv_bias = True, attn_p=0., proj_p=0.):
        super(Attention,self).__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim//n_heads
        self.scale = self.head_dim** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias = qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """ Run forward pass

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`
        
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)` --> same dimensionality
        """

        n_samples, n_tokens, dim = x.shape
        if dim != self.dim:
            raise ValueError(f"Input dim {dim} does not match layer dim {self.dim}")

        qkv = self.qkv(x) # (n_samples, n_patches + 1, dim*3)
        qkv = qkv.reshape(n_samples, n_tokens, 3 , self.n_heads, self.head_dim)
        qkv = qkv.permute(2,0,3,1,4) # (3, n_samples, n_heads, n_patches+1, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        k_t = k.transpose(-2,-1) # (n_samples, n_heads, head_dim, n_patches + 1)
        dp = (q@k_t) *self.scale 
        attn = dp.softmax(dim=-1)
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v # (n_samples, n_heads, n_patches+1, head_dim)
        weighted_avg = weighted_avg.transpose(1,2) 
        weighted_avg = weighted_avg.flatten(2) # (n_samples, n_patches+1, dim)

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)

        return x
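
# Sanity check (illustrative values): attention preserves the token-sequence shape,
# mapping (n_samples, n_patches + 1, dim) to the same shape.
_attn = Attention(dim=768, n_heads=12)
assert _attn(torch.randn(2, 577, 768)).shape == (2, 577, 768)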

class MLP(nn.Module):
    """Multilayer perceptron

    Parameters
    -----------
    in_features : int
        # of input features
    hidden_features : int
        # of nodes in the hidden layer
    out_features : int
        # of output features
    p : float
        Dropout prob
    
    Attributes
    ----------
    fc : nn.Linear
        First linear layer

    act : nn.GELU
        GELU activation

    fc2 : nn.Linear
        Second linear layer

    drop : nn.Dropout
        Dropout layer
    """
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """ Run forward pass

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches+1, in_features)`
        
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches+1, out_features)`
        """
        x = self.fc(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(nn.Module):
    """Transformer Block

    Parameters
    ----------
    dim : int
        Embedding dimension

    n_heads : int
        # of attention head

    mlp_ratio : float 
        Determines the hidden dimension size of the MLP module with respect to dim
    
    qkv_bias : boolean
    
    p, attn_p : float
        Dropout prob

    Attributes
    ----------
    norm1, norm2 : LayerNorm
        Layer Normalization
    attn : Attention
        Attention module
    mlp : MLP
        MLP module
    """
    def __init__(self, dim, n_heads, mlp_ratio =4.0 , qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(dim, eps = 1e-6)
        self.norm2 = nn.LayerNorm(dim, eps = 1e-6)
        self.attn = Attention(dim, n_heads=n_heads,qkv_bias=qkv_bias,attn_p=attn_p,proj_p=p)
        hidden_features = int(dim*mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_features, out_features=dim, p=p)

    def forward(self, x):
        """ Run forward pass

        Parameters
        ---------
        x : torch.Tensor
            Shape `(n_samples, n_patches+1, dim)`
        
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches+1, dim)`
        """
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x

class VisionTransformer(nn.Module):
    """ Simplified implementation of ViT

    Parameters
    ----------
    img_size : int
    
    patch_size : int

    in_chans : int

    n_classes : int

    embed_dim : int
        Dimensionality of token/patch embeddings
    depth : int
        # of blocks
    n_heads : int
    
    mlp_ratio : float

    qkv_bias : boolean

    p, attn_p : float

    Attributes
    ----------
    patch_embed : PatchEmbed
        Instance of PatchEmbed layer

    cls_token : nn.Parameter
        Learnable parameter that will represent the first token in the sequence
    pos_emb : nn.Parameter
        Positional embedding of the cls token + all the patches. 
        It has `(n_patches + 1) * embed_dim` elements
    
    pos_drop : nn.Dropout

    blocks : nn.ModuleList
        List of Block modules
    norm : nn.LayerNorm
    """
    def __init__(self, img_size=384, patch_size=16,in_chans=3, n_classes=1000,embed_dim=768,depth=12,n_heads=12,mlp_ratio=4.,qkv_bias=True,p=0.,attn_p=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,in_chans=in_chans, embed_dim=embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_emb = nn.Parameter(torch.zeros(1,1+self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(p)
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim = embed_dim,
                    n_heads=n_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    p=p,
                    attn_p=attn_p
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim,eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)  # final classification head

    def forward(self, x):
        """ Run forward pass

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size )`
        
        Returns 
        -------
        logits : torch.Tensor
            Logits over all classes - `(n_samples, n_classes)`
        """
        n_samples = x.shape[0]
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(n_samples,-1,-1)
        x = torch.cat((cls_token,x), dim=1)
        x = x + self.pos_emb
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)

        cls_token_final = x[:,0]
        x = self.head(cls_token_final)

        return x
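
# End-to-end smoke test (not part of the original post; the 384x384 resolution and
# 1000 classes below are just the illustrative defaults used above).
if __name__ == "__main__":
    model = VisionTransformer(img_size=384, patch_size=16, n_classes=1000)
    dummy = torch.randn(2, 3, 384, 384)
    print(model(dummy).shape)  # expected: torch.Size([2, 1000])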
