Hello! The paper we will review today (paper review + code review) is "An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale", the paper that became the decisive turning point for Transformers and Attention being adopted in computer vision. I have been very interested in this area lately, so today I decided to review this paper.
In existing RNN-based seq2seq models, the computation for the next time step cannot begin until the computation for the previous time step has finished, so parallelized processing was impossible.
The RNN architecture also suffers from its chronic long-term dependency problem, which means that performance on sequence processing degrades as the time steps grow longer.
To address these problems, the Transformer was proposed: a model that processes sequences with an encoder-decoder architecture built from Attention alone, training very quickly while also achieving excellent performance.
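For reference, the core operation of the Transformer is scaled dot-product attention. In the notation of the original "Attention Is All You Need" paper, with queries Q, keys K, values V, and key dimension d_k:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

The same 1/√d_k scaling reappears later in this post's code as self.scale = dim_head ** -0.5 inside the Attention class.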
Image Recognition (Classification) refers to the task of feeding an image into an algorithm as input and getting the class label that image belongs to as output.
As in the picture below, if you feed in a photo of a cat, it recognizes (classifies) it as a cat.
As the picture above shows, CNN (Convolutional Neural Network)-based models (e.g., ResNet, UNet, EfficientNet) have long been the dominant approach in the CV (Computer Vision) domain.
However, with the growth of Self-Attention and the Transformer in the NLP (Natural Language Processing) domain, there is a growing trend of trying to use CNNs together with Attention. This paper (study) is one such attempt.
import torch
import torch.nn as nn
from torch import Tensor, einsum  # einsum is used by the Attention module below
from torchsummary import summary
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import numpy as np
import os
import copy
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
%pip install einops
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.long
def pair(t):
return t if isinstance(t, tuple) else (t, t)
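For example, the pair helper lets every size argument be given either as a single integer or as an explicit (height, width) tuple:

# quick check of the pair helper: scalars become (h, w) tuples, tuples pass through
print(pair(224))         # (224, 224)
print(pair((224, 160)))  # (224, 160)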
PreNorm Class
# Define PreNorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
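As a quick sanity check, here is a minimal sketch of PreNorm wrapping an arbitrary sub-layer (the dimensions are illustrative assumptions, not values from the paper):

# LayerNorm runs first, then the wrapped sub-layer; the shape is preserved
block = PreNorm(dim=64, fn=nn.Linear(64, 64))
x = torch.randn(1, 4, 64)   # (batch, tokens, dim)
print(block(x).shape)       # torch.Size([1, 4, 64])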
FeedForward Class
# Define FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
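A brief shape check (sizes are illustrative): the MLP expands each token to hidden_dim and projects back to dim, so the token shape is unchanged:

ff = FeedForward(dim=128, hidden_dim=256)
x = torch.randn(2, 65, 128)  # (batch, 1 [CLS] + 64 patches, dim)
print(ff(x).shape)           # torch.Size([2, 65, 128])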
Attention Class
# Define Attention
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
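A minimal sketch of running the Attention module (the sizes here are illustrative assumptions): multi-head self-attention maps a (batch, tokens, dim) tensor to a tensor of the same shape.

attn = Attention(dim=128, heads=8, dim_head=16)
x = torch.randn(2, 65, 128)  # (batch, tokens, dim)
print(attn(x).shape)         # torch.Size([2, 65, 128])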
Transformer Class
# Define Transformer
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
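And a quick check of the full encoder stack (depth and sizes are illustrative): because every block is residual (x = attn(x) + x, x = ff(x) + x), the shape stays fixed through all layers.

enc = Transformer(dim=128, depth=2, heads=8, dim_head=16, mlp_dim=256)
x = torch.randn(2, 65, 128)
print(enc(x).shape)  # torch.Size([2, 65, 128])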
You can see the [0,*] token at the front of the output sequence; it can be viewed as a token that holds the information of the entire image (a.k.a. the [CLS] token). In this step, the [CLS] token is fed into an MLP (Multi-Layer Perceptron) head to perform classification.

Wait a moment! What is the einops library?
- Einstein notation is a way of writing down complex tensor operations. Many of the operations used in deep learning can be expressed concisely in Einstein notation.
- einops (https://github.com/arogozhnikov/einops) is a library that supports several frameworks at once, such as PyTorch and TensorFlow, and lets you use this Einstein notation.
Wait a moment! What is the Rearrange function?
- The Rearrange layer can be thought of as a function that easily transforms tensor shapes.
- Let's check how it works intuitively with the example below!
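Here is a small sketch using the exact Rearrange pattern that ViT's patch embedding uses, on a tiny 4x4 "image" (the sizes are illustrative):

imgs = torch.randn(1, 3, 4, 4)  # (b, c, h, w)
to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=2, p2=2)
print(to_patches(imgs).shape)   # torch.Size([1, 4, 12]) -> 4 patches, each flattened to 2*2*3 = 12 values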
Wait a moment! What is the einsum function?
- Einsum notation uses a small domain-specific language to express all of these matrix and tensor operations.
- Simply put, it is a function that lets us implement the matrix operation we want by writing it down intuitively.
- Let's look at a few examples (given matrices X and Y):
- Transpose: np.einsum('ij->ji', X)
- Matrix sum: np.einsum('ij->', X)
- Matrix row sum: np.einsum('ij->i', X)
- Matrix column sum: np.einsum('ij->j', X)
- Matrix multiplication: np.einsum('ij,jk->ik', X, Y)
- Batched matrix multiplication: np.einsum('bik,bkj->bij', X, Y)
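A short runnable demo of the identities above (the concrete matrices are illustrative):

import numpy as np

X = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
Y = np.arange(12).reshape(3, 4)

print(np.einsum('ij->ji', X).shape)  # (3, 2): transpose
print(np.einsum('ij->', X))          # 15: sum of all elements
print(np.einsum('ij->i', X))         # [ 3 12]: row sums
print(np.einsum('ij,jk->ik', X, Y))  # equivalent to X @ Y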
# ViT Class
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__() # call the parent class's __init__ via super()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
        # assert statement: raises an AssertionError when the following condition is not True
        # patch size condition
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
        # pooling condition
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), # 3D image -> 2D sequence of flattened patches
nn.Linear(patch_dim, dim), # Linear Projection
)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # initialize the position embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # initialize the [CLS] token
        self.dropout = nn.Dropout(emb_dropout) # define embedding dropout
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) # declare the Transformer encoder
        self.pool = pool
        self.to_latent = nn.Identity() # identity mapping (pass-through) before the MLP head
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
    def forward(self, img):
        x = self.to_patch_embedding(img)                        # (b, num_patches, dim)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) # copy the [CLS] token for every sample in the batch
        x = torch.cat((cls_tokens, x), dim=1)                   # prepend [CLS]: (b, n + 1, dim)
        x += self.pos_embedding[:, :(n + 1)]                    # add the position embedding
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] # mean pooling or take the [CLS] token
        x = self.to_latent(x)
        return self.mlp_head(x)                                 # (b, num_classes)
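Finally, a minimal usage sketch: instantiate a small ViT and run a dummy batch through it (the hyperparameters below are illustrative, not the paper's configurations).

model = ViT(
    image_size=224, patch_size=16, num_classes=10,
    dim=128, depth=6, heads=8, mlp_dim=256,
).to(device)

dummy = torch.randn(4, 3, 224, 224).to(device)  # (batch, channels, height, width)
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 10]) - one score per class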
An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale, ICLR 2021 - Alexey Dosovitskiy et al. (https://arxiv.org/abs/2010.11929)
Introduction to Natural Language Processing with Deep Learning (딥 러닝을 이용한 자연어 처리 입문) - Won Joon Yoo et al. (https://wikidocs.net/book/2155)
The Illustrated Transformer - Jay Alammar (https://jalammar.github.io/illustrated-transformer)
ViT Source Code - lucidrains (https://github.com/lucidrains/vit-pytorch/blob/64a2ef6462bde61db4dd8f0887ee71192b273692/vit_pytorch/vit.py)