[Paper Review] An Image Is Worth 16x16 Words : Transformers for Image Recognition at Scale (Vision Transformer)

์„œ์ฟ ยท2021๋…„ 11์›” 6์ผ
1
post-thumbnail

์„ ์ • ์ด์œ 

์•ˆ๋…•ํ•˜์„ธ์š”! ์˜ค๋Š˜ ๋…ผ๋ฌธ๋ฆฌ๋ทฐ, ์ฝ”๋“œ๋ฆฌ๋ทฐํ•ด๋ณผ ๋…ผ๋ฌธ์€ "An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale" ๋กœ, ์ปดํ“จํ„ฐ ๋น„์ „์—์„œ Transformer์™€ Attention์ด ์“ฐ์ด๊ฒŒ ๋œ ๊ฒฐ์ •์  ๊ณ„๊ธฐ(?)๊ฐ€ ๋œ ๋…ผ๋ฌธ์ž…๋‹ˆ๋‹ค. ์ตœ๊ทผ ์ด์ชฝ ๋ถ„์•ผ์— ๊ด€์‹ฌ์ด ๋งŽ๋‹ค ๋ณด๋‹ˆ ์˜ค๋Š˜์€ ์ด ๋…ผ๋ฌธ์„ ๋ฆฌ๋ทฐํ•˜๊ฒŒ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

๋…ผ๋ฌธ๋ฆฌ๋ทฐ

Background

(Self) Attention

  • Attention์˜ ๊ธฐ๋ณธ ์•„์ด๋””์–ด๋Š” ๋””์ฝ”๋”์—์„œ ์ถœ๋ ฅ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ•˜๋Š” ๋งค ์‹œ์ (time step)๋งˆ๋‹ค, ์ธ์ฝ”๋”์—์„œ์˜ ์ „์ฒด ์ž…๋ ฅ ๋ฌธ์žฅ์„ ๋‹ค์‹œ ํ•œ ๋ฒˆ ์ฐธ๊ณ ํ•ฉ๋‹ˆ๋‹ค.
  • ๋‹จ, ์ „์ฒด ์ž…๋ ฅ ๋ฌธ์žฅ์„ ์ „๋ถ€ ๋‹ค ๋™์ผํ•œ ๋น„์œจ๋กœ ์ฐธ๊ณ ํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ, ํ•ด๋‹น ์‹œ์ ์—์„œ ์˜ˆ์ธกํ•ด์•ผ ํ•  ๋‹จ์–ด์™€ ์—ฐ๊ด€์ด ์žˆ๋Š” ์ž…๋ ฅ ๋‹จ์–ด ๋ถ€๋ถ„์„ ์ข€ ๋” ์ง‘์ค‘(attention)ํ•ด์„œ ๋ณด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

Q/K/V-1

  • ์ฃผ์–ด์ง„ '์ฟผ๋ฆฌ(Query)'์— ๋Œ€ํ•ด์„œ ๋ชจ๋“  'ํ‚ค(Key)'์™€์˜ ์œ ์‚ฌ๋„๋ฅผ ๊ฐ๊ฐ ๊ตฌํ•ฉ๋‹ˆ๋‹ค. - ๊ทธ๋ฆฌ๊ณ  ๊ตฌํ•ด๋‚ธ ์œ ์‚ฌ๋„๋ฅผ ๊ฐ€์ค‘์น˜๋กœ ํ•˜์—ฌ ํ‚ค์™€ ๋งตํ•‘๋˜์–ด ์žˆ๋Š” ๊ฐ๊ฐ์˜ '๊ฐ’(Value)'์— ๋ฐ˜์˜ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์œ ์‚ฌ๋„๊ฐ€ ๋ฐ˜์˜๋œ '๊ฐ’(Value)'์„ ๋ชจ๋‘ ๊ฐ€์ค‘ํ•ฉํ•˜์—ฌ ๋ฆฌํ„ดํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.
  • '์ฟผ๋ฆฌ(Query)', 'ํ‚ค(Key)', '๊ฐ’(Value)'์˜ ์ •์˜๋Š” ์˜์–ด๋กœ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

Q/K/V-2

Q/K/V-3

Transformer

  • ๊ธฐ์กด RNN ๊ธฐ๋ฐ˜ seq2seq ๋ชจ๋ธ์—์„œ๋Š” ์ด์ „ ์‹œ์ ์˜ ์—ฐ์‚ฐ์ด ๋๋‚˜๊ธฐ ์ „์—๋Š” ๋‹ค์Œ ์‹œ์ ์˜ ์—ฐ์‚ฐ์ด ๋ถˆ๊ฐ€๋Šฅํ•˜์—ฌ ๋ณ‘๋ ฌํ™”(parallelize) ๋œ ์—ฐ์‚ฐ์ฒ˜๋ฆฌ๊ฐ€ ๋ถˆ๊ฐ€๋Šฅํ–ˆ์Šต๋‹ˆ๋‹ค.

  • RNN๊ตฌ์กฐ๋Š” ๊ณ ์งˆ์  ๋ฌธ์ œ์ธ Long-term dependency ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•˜์˜€๊ณ , ์ด๋Š” ๊ณง ํƒ€์ž„ ์Šคํ…(time step)์ด ๊ธธ์–ด์งˆ ์ˆ˜๋ก ์‹œํ€€์Šค ์ฒ˜๋ฆฌ์˜ ์„ฑ๋Šฅ์ด ๋–จ์–ด์ง์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.

  • ์ด๋Ÿฌํ•œ ๋ฌธ์ œ์ ๋“ค์„ ๋ณด์™„ํ•˜๊ธฐ ์œ„ํ•ด Attention๋งŒ์œผ๋กœ Encoder, Decoder ๊ตฌ์กฐ๋ฅผ ๋งŒ๋“ค์–ด ์‹œํ€€์Šค๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๋ชจ๋ธ์„ ์ œ์•ˆํ•จ์œผ๋กœ์จ ํ•™์Šต ์†๋„๊ฐ€ ๋งค์šฐ ๋น ๋ฅด๋ฉฐ ์„ฑ๋Šฅ๋„ ์šฐ์ˆ˜ํ•œ Transformer๊ฐ€ ์ œ์•ˆ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

Encoder/Decoder

  • ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์—ฌ๋Ÿฌ๊ฐœ์˜ Head๋ฅผ ์‚ฌ์šฉํ•˜๋Š” Multi-head Attention์„ ํ†ตํ•ด ๋‹ค์–‘ํ•œ aspect์— ๋Œ€ํ•ด์„œ ๋ชจ๋ธ์ด ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜์˜€์Šต๋‹ˆ๋‹ค.

MHSA

Image Recognition (Classification)

  • Image Recognition (Classification)์€ ์ด๋ฏธ์ง€๋ฅผ ์•Œ๊ณ ๋ฆฌ์ฆ˜์— ์ž…๋ ฅ(input)ํ•ด์ฃผ๋ฉด, ๊ทธ ์ด๋ฏธ์ง€๊ฐ€ ์†ํ•˜๋Š” class lable์„ ์ถœ๋ ฅ(output)ํ•ด์ฃผ๋Š” task๋ฅผ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.

  • ์•„๋ž˜ ๊ทธ๋ฆผ ์ฒ˜๋Ÿผ ๊ณ ์–‘์ด ์‚ฌ์ง„์„ ๋„ฃ์–ด์ฃผ๋ฉด ๊ณ ์–‘์ด ๋ผ๊ณ  ์ธ์‹(๋ถ„๋ฅ˜)ํ•ด๋ƒ…๋‹ˆ๋‹ค.
    Image Classification

  • ์œ„์— ๊ทธ๋ฆผ์ฒ˜๋Ÿผ ์—ฌ์ง€๊ป CV(Computer Vision) ๋„๋ฉ”์ธ์—์„œ๋Š” CNN(Convolutional Neural Network)๋ฅผ ์‚ฌ์šฉํ•œ ๋ชจ๋ธ๋“ค์ด ๋งŽ์ด ์‚ฌ์šฉ๋˜์–ด ์˜ค๊ณ  ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค. (Ex. ResNet, UNet, EfficientNet ๋“ฑ)

  • ํ•˜์ง€๋งŒ, NLP(Natural Language Processing) ๋„๋ฉ”์ธ์—์„œ์˜ Self-Attention๊ณผ Transformer์˜ ์„ฑ์žฅ์œผ๋กœ ์ธํ•ด CNN๊ณผ Attention์„ ํ•จ๊ป˜ ์ด์šฉํ•˜๋ ค๋Š” ์ถ”์„ธ๊ฐ€ ์ฆ๊ฐ€ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๋ณธ ๋…ผ๋ฌธ(์—ฐ๊ตฌ) ์—ญ์‹œ ๊ทธ๋Ÿฌํ•œ ์‹œ๋„ ์ค‘ ํ•˜๋‚˜์ž…๋‹ˆ๋‹ค.

Vision Transformer

  • Vision Transformer์˜ ๊ฐœ๋…์€ Transformer๊ฐ€ ์–ด๋–ป๊ฒŒ ์ž‘๋™ํ•˜๋Š”์ง€ ์•„๋Š” ์‚ฌ๋žŒ๋“ค์ด๋ผ๋ฉด ์‰ฝ๊ฒŒ ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์•„๋ž˜ ๊ทธ๋ฆผ์ด ์†”์งํžˆ ๋ณธ ๋…ผ๋ฌธ์— ์ „๋ถ€์ด๊ธฐ ๋•Œ๋ฌธ์ด์ฃ .

ViT

  • Vision Transformer๋Š” image recognition task์— ์žˆ์–ด์„œ Convolution์„ ์•„์˜ˆ ์—†์• ๊ณ , Transformer Encoder๋งŒ์„ ์‚ฌ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๊ฐ๊ฐ์˜ ์ˆœ์„œ๋Š” ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

0. Prerequisites

  • ํ•„์š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ IMPORT
import torch
import torch.nn as nn
from torch import Tensor
from torchsummary import summary

import torch.optim as optim
from torch.optim import lr_scheduler

import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import numpy as np
import os
import copy
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

%pip install einops
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.long
  • Define Helper Function
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

Transformer

  • Define PreNorm Class
# Define PreNorm
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
  • Define FeedForward Class
# Define FeedForward
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
  • Define Attention Class
# Define Attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
  • Define Transformer Class
# Define Transformer
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

step 1. Splitting Image into fixed-size patches

  • ๊ฐ€์žฅ ๋จผ์ € ์ด๋ฏธ์ง€๋ฅผ ๊ณ ์ •๋œ ์‚ฌ์ด์ฆˆ์˜ ํŒจ์น˜๋“ค๋กœ ๋ถ„ํ• ํ•˜์—ฌ ๋ชจ๋ธ์— ๋„ฃ์–ด์ค๋‹ˆ๋‹ค.

ViT1

step 2. Linearly embed each patches

  • ๊ฐ๊ฐ์˜ ์ด๋ฏธ์ง€ ํŒจ์น˜๋“ค์— ๋Œ€ํ•ด Linear Embedding์„ ์ˆ˜ํ–‰ํ•ด์ค๋‹ˆ๋‹ค. (D์ฐจ์›์œผ๋กœ)
    ViT2

step 3. Add positional embedding

  • ์ด์ œ ๊ฐ๊ฐ์˜ ์ด๋ฏธ์ง€ ํŒจ์น˜๋“ค์ด ์–ด๋–ค ์œ„์น˜์— ์žˆ๋Š”๊ฐ€์— ๋Œ€ํ•œ ์ •๋ณด๋„ ๋ชจ๋ธ์— ๋„ฃ์–ด์ฃผ์–ด์•ผ๊ฒ ์ฃ ? ์ด๋Ÿฐ ์œ„์น˜์— ๋Œ€ํ•œ ์ •๋ณด๋ฅผ ์šฐ๋ฆฌ๋Š” position embedding์ด๋ผ๊ณ  ํ•˜๋ฉฐ, ์•ž์—์„œ ๊ตฌํ•œ Embedding์— ๋ถ™์—ฌ์ฃผ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

ViT3

step 4. Feed embedding vector into Transformer Encoder

  • ๊ฐ๊ฐ์˜ ์ด๋ฏธ์ง€ ํŒจ์น˜๋“ค์— ๋Œ€ํ•œ ์œ„์น˜ ์ •๋ณด์™€ ์ž„๋ฐฐ๋”ฉ ๊ฐ’์„ Transformer Encoder๋กœ ๋„ฃ์–ด์ค๋‹ˆ๋‹ค. Transformer Encoder๋Š” ์•„๋ž˜ ๊ทธ๋ฆผ(์šฐ์ธก)๊ณผ ๊ฐ™์ด ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

ViT4

step 5. Use [CLS] token for Classification

  • ์ด๋ฏธ ๋ˆˆ์น˜ ์ฑ„์‹  ๋ถ„๋“ค๋„ ์žˆ๊ฒ ์ง€๋งŒ, Transformer Encoder์— ๋“ค์–ด๊ฐ„ ๊ฐ๊ฐ์˜ ์ด๋ฏธ์ง€ ํŒจ์น˜๋“ค์— ๋Œ€ํ•œ ์œ„์น˜ ์ •๋ณด์™€ ์ž„๋ฐฐ๋”ฉ ๊ฐ’ ์™ธ์—๋„ ์•ž์— [0,*]์ด ์žˆ๋Š” ๊ฒƒ์„ ํ™•์ธํ•˜์‹ค ์ˆ˜ ์žˆ๋Š” ๋ฐ ์ด๋Š” ์ „์ฒด ์ด๋ฏธ์ง€์˜ ๋ชจ๋“  ์ •๋ณด๋ฅผ ๋‹ด๊ณ  ์žˆ๋Š” ํ† ํฐ์ด๋ผ๊ณ  ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. (A.K.A. [CLS]ํ† ํฐ) ์ด๋ฒˆ ๋‹จ๊ณ„์—์„œ๋Š” ์ด๋Ÿฌํ•œ [CLS]ํ† ํฐ์„ ์‚ฌ์šฉํ•˜์—ฌ MLP(Multi-layer Perceptron)์— ํƒœ์›Œ Classificatin์„ ์ˆ˜ํ–‰ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

ViT5

๐Ÿ“Œ ์—ฌ๊ธฐ์„œ ์ž ๊น! einops ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ?

  • Einstein notation ์€ ๋ณต์žกํ•œ ํ…์„œ ์—ฐ์‚ฐ์„ ํ‘œ๊ธฐํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค. ๋”ฅ๋Ÿฌ๋‹์—์„œ ์“ฐ์ด๋Š” ๋งŽ์€ ์—ฐ์‚ฐ์€ Einstein notation ์œผ๋กœ ์‰ฝ๊ฒŒ ํ‘œ๊ธฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • einops (https://github.com/arogozhnikov/einops)๋Š” pytorch, tensorflow ๋“ฑ ์—ฌ๋Ÿฌ ํ”„๋ ˆ์ž„์›Œํฌ๋ฅผ ๋™์‹œ์— ์ง€์›ํ•˜๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋กœ ์ด๋Ÿฌํ•œ Einstein notation์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒํ•ฉ๋‹ˆ๋‹ค.

๐Ÿ“Œ ์—ฌ๊ธฐ์„œ ์ž ๊น! Rearrange ํ•จ์ˆ˜?

  • Rearrange ํ•จ์ˆ˜๋Š” shape๋ฅผ ์‰ฝ๊ฒŒ ๋ณ€ํ™˜ํ•ด์ฃผ๋Š” ํ•จ์ˆ˜๋ผ๊ณ  ์ƒ๊ฐํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.
  • ๋ฐ‘์— ๊ทธ๋ฆผ์œผ๋กœ ์–ด๋–ป๊ฒŒ ์ž‘๋™ํ•˜๋Š”์ง€ ์ง๊ด€์ ์œผ๋กœ ํ™•์ธํ•ด๋ณด์‹œ์ฃ !
    Rearrangeํ•จ์ˆ˜

๐Ÿ“Œ ์—ฌ๊ธฐ์„œ ์ž ๊น! einsum ํ•จ์ˆ˜?

  • Einsum ํ‘œ๊ธฐ๋ฒ•์€ ํŠน์ˆ˜ํ•œ Domain Specific Language๋ฅผ ์ด์šฉํ•ด ์ด ๋ชจ๋“  ํ–‰๋ ฌ, ์—ฐ์‚ฐ์„ ํ‘œ๊ธฐํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.
  • ์‰ฝ๊ฒŒ ๋งํ•ด ์šฐ๋ฆฌ๊ฐ€ ๊ตฌํ•˜๊ณ  ์‹ถ์€ ํ–‰๋ ฌ ์—ฐ์‚ฐ์„ ์ง๊ด€์ ์œผ๋กœ ์ •์˜ํ•ด์„œ ๊ตฌํ•˜๊ฒŒ ํ•ด์ฃผ๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
  • ๋ช‡ ๊ฐ€์ง€ ์˜ˆ์‹œ๋กœ ์‚ดํŽด๋ณด์‹œ์ฃ  (given X(matrix), Y(matrix))
    - Transpose : np.einsum("ij->ji", X)
    - Matrix sum : np.einsum("ij->", X)
    - Matrix row sum : np.einsum("ij->i", X)
    - Matrix col sum : np.einsum("ij->j", X)
    - Matrix Multiplication : np.einsum('ij,j->i', X, Y)
    - Batched Matrix Multiplication : np.einsum('bik,bkj->bij', X, Y)

Vision Transformer ์ฝ”๋“œ

# ViT Class
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__() # super()๋กœ ๊ธฐ๋ฐ˜ ํด๋ž˜์Šค์˜ __init__ ๋ฉ”์„œ๋“œ ํ˜ธ์ถœ

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # assert ๋ฌธ : ๋’ค์˜ ์กฐ๊ฑด์ด True๊ฐ€ ์•„๋‹ˆ๋ฉด AssertError๋ฅผ ๋ฐœ์ƒ
        # patch size ์กฐ๊ฑด
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # pooling ์กฐ๊ฑด 
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), # 3 -> 2
            nn.Linear(patch_dim, dim), # Linear Projection
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # position embedding ์ดˆ๊ธฐํ™”
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # [CLS] ํ† ํฐ ์ดˆ๊ธฐํ™”
        self.dropout = nn.Dropout(emb_dropout) # Dropout ์ •์˜

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) # Transformer ์„ ์–ธ

        self.pool = pool
        self.to_latent = nn.Identity() # ๋žœ๋ค์„ฑ ์ œ๊ฑฐ

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        
        return self.mlp_head(x)

Reference

  1. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, ICLR 2019 - Alexey Dosovitskiy et. al.

  2. ๋”ฅ ๋Ÿฌ๋‹์„ ์ด์šฉํ•œ ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ ์ž…๋ฌธ - ์œ ์›์ค€ ์™ธ 1์ธ (https://wikidocs.net/book/2155)

  3. The Illustrated Transformer -
    Jay Alammar (https://jalammar.github.io/illustrated-transformer)

  4. ViT Source Code - lucidrains (https://github.com/lucidrains/vit-pytorch/blob/64a2ef6462bde61db4dd8f0887ee71192b273692/vit_pytorch/vit.py)

profile
Always be passionate โœจ

0๊ฐœ์˜ ๋Œ“๊ธ€