๐ ๋ณธ ๋ฆฌ๋ทฐ๋ ViT ๋ฐ ๋ฆฌ๋ทฐ๋ฅผ ์ฐธ๊ณ ํด ์์ฑํ์ต๋๋ค.
๐ Using pure transformer for Image Recognition
๐ Fewer computational resources to train
โ ์๋ ViT์ ์ ๋ฐ์ ์ธ ์ํคํ ์ฒ ๊ตฌ์กฐ์ ๋๋ค. ์ด๋ ๊ธฐ์กด์ Transformer๋ Bert์ ๋งค์ฐ ์ ์ฌํ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ๋ ผ๋ฌธ์ ์ ์๋ ๊ธฐ์กด์ transformer ๊ตฌ์กฐ๋ฅผ ์ต๋ํ ๋น์ทํ๊ฒ ์ค๊ณํด image classification์ ์งํํ์ต๋๋ค. ๊ทธ๋ ๊ธฐ์ ๊ธฐ๋ณธ์ ์ธ Transformer๋ Bert ๊ตฌ์กฐ๋ฅผ ์ฝ์ด์ผ ์ดํดํ๊ธฐ ์์ํฉ๋๋ค.
โ ViT๋ ๊ธฐ์กด์ CNN๋ณด๋ค inductive bias๊ฐ ๋ถ์กฑํ๋ค๊ณ ์ด์ผ๊ธฐํฉ๋๋ค. ๊ทธ ๊ฒฐ๊ณผ ์ ์ ๋ฐ์ดํฐ ์ ๋ณด๋ค๋ ๋ง์ ๋ฐ์ดํฐ ์ ์์ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก๋ ๋ง์ ๋ฐ์ดํฐ์ ์ด ์๋ค๋ฉด SOTA์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ฃผ๋ฉฐ, ์ ์ ํ๋ผ๋ฏธํฐ ์๋ก ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
โ Inductive bias
โ ๋ ผ๋ฌธ์์๋ Transformer๊ฐ ์๋์ ์ผ๋ก CNN๋ณด๋ค Inductive bias์ด ๋ถ์กฑํ๋ค๊ณ ์ด์ผ๊ธฐํฉ๋๋ค. ๊ธฐ์กด์ CNN์์๋ translation equivariance์ locality๋ผ๋ ๊ฐ์ ์ด ์กด์ฌํ์ง๋ง, Transformer์์๋ ์ด๋ฅผ ๊ฐ์ ํ ์ ์์ต๋๋ค. ๊ทธ ๊ฒฐ๊ณผ ์ถฉ๋ถํ์ง ๋ชปํ data์์๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ง ๋ชปํฉ๋๋ค(ex. ImageNet)
โ ViT์ ๋ชจ๋ธ ๋์์ธ์ ๊ธฐ์กด์ Transformer์ ๊ฐ๋ฅํ ํ ์ ์ฌํ๊ฒ ๊ตฌ์ฑํ๋ค๊ณ ํฉ๋๋ค.
โ ViT๋ ๊ธฐ๋ณธ์ ์ธ ์ด๋ฏธ์ง๋ฅผ Patch๋ก ๋ถํ ํด ์งํ๋ฉ๋๋ค. ๋ณธ ๋ ผ๋ฌธ์์๋ ๊ฐ patch๋ฅผ (16x16), (14x14) ์ฌ์ด์ฆ๋ฅผ ์ฌ์ฉํ๊ณ ์์ผ๋ฉฐ, ์ด๋ ์ด๋ฏธ์ง์ resolution๊ณผ๋ ๊ด๊ณ์์ด ์ผ์ ํฉ๋๋ค. ๋ํ ๊ธฐ์กด์ Transformer ๊ตฌ์กฐ์ ๋ค๋ฅธ ์ ์ Norm์ ๋จผ์ ์ํํ๋ค๋ ์ ์ ๋๋ค.
โ ๊ธฐ์กด์ Transformer์ ๊ฒฝ์ฐ์ 1D sequence of token์ด ํ์ํ๋ค๋ฉด, ViT๋ ์ด๋ฏธ์ง(2D)๋ฅผ ๋ค๋ฃจ๊ธฐ ๋๋ฌธ์ ์์ ๊ฐ์ด Reshape๋ฅผ ํ์๋กํฉ๋๋ค. ์ ๋ํด ๊ตฌ์กฐ๋ก reshape๋ฅผ ์งํํฉ๋๋ค. ์ฌ๊ธฐ์ ๋ ์๋ณธ ์ด๋ฏธ์ง์ ๋์ด์ ๋๋น์ด๋ฉฐ, ๋ ์ฑ๋์ ์, ๋ ๊ฐ๊ฐ์ ์ด๋ฏธ์ง patch์ ํฌ๊ธฐ์ด๋ฉฐ, (=ํจ์น์ ์) ์ ๋๋ค.
โ Transformer์์ D ์ฌ์ด์ฆ์ ์์ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํจ์ผ๋ก, patches๋ค์ flattenํด D dimenstion์ผ๋ก ๋งคํํฉ๋๋ค. ์์ projection ๊ฒฐ๊ณผ๋ฅผ patch embedding์ด๋ผ๊ณ ์ด์ผ๊ธฐ ํฉ๋๋ค.
โ ๋ํ Bert์ ์ ์ฌํ๊ฒ ์์์ง์ ์ ํ์ต๊ฐ๋ฅํ ์๋ฒ ๋ฉ [class] token์ ์ถ๊ฐํฉ๋๋ค(). ๊ฐ class ํ ํฐ์ impage representation์ Transformer encoder์์ output์ผ๋ก ๋ํ๋ ๋๋ค.
โ classification head๋ pre-training๊ณผ fine-tuning์์ ์ attached ๋๋ฉฐ, ์ด๋ MLP๋ก ๊ตฌ์ฑ๋์ด์๊ณ , pre-training์์๋ ํ๋์ hidden-layer๋ก, fine-tuning์์๋ sigle linear layer๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
โ ๋ํ ์ถ๊ฐ์ ์ผ๋ก patch embedding์ position embeddingํ ๊ฐ์ผ ๋ํฉ๋๋ค. position embedding๋ ๊ธฐ์กด์ ํ์ต๊ฐ๋ฅํ 1D position embedding์ ์งํํ๋๋ฐ, ์ด๋ 2D๋ก ์งํํ์ผ๋๋ณด๋ค ์ฑ๋ฅ์ด ์ข์๊ธฐ ๋๋ฌธ์ ๋๋ค.
โ ์๋ ViT์ ์ ์ฒด์ ์ธ ๊ตฌ์กฐ๋ฅผ ์์ ์์์ ํตํด ํ์ธ ํ ์ ์์ต๋๋ค. MSA๋ ๊ธฐ์กด transformer์ multiheaded self-attention์ ์๋ฏธํ๋ฉฐ, LN์ Layernorm์ ์๋ฏธํฉ๋๋ค.
โ ์์ Inductive bias์ ์ ์์ ๋ํด ๊ฐ๋จํ ์ธ๊ธํ์ต๋๋ค. ViT์์๋ CNN๋ณด๋ค ์ ์ inductive bias๊ฐ ์กด์ฌํฉ๋๋ค. self-attention layers์์๋ globalํ๊ธฐ์, ์ค์ง MLP์์๋ง local, translationally equivariant์ด ์กด์ฌํฉ๋๋ค.
โ ViT์์๋ rawํ image patch๊ฐ ์๋, CNN์ ํตํด ์ถ์ถ๋ feature map์ input์ผ๋ก ์ฌ์ฉํ ํ์ด๋ธ๋ฆฌ๋ ๋ชจ๋ธ์ ์คํํ์ต๋๋ค. ๊ฐ๊ฐ์ patch๋ค์ 1x1๋ก input์ผ๋ก ๋ค์ด๊ฐ๊ฒ ๋ฉ๋๋ค.
โ ์ ํ์ ์ผ๋ก, ViT๋ ํฐ ๋ฐ์ดํฐ ์ ์ pre-trained ํ ํ fine-tune downstream tasks๋ฅผ ์งํํฉ๋๋ค. fine-tune์ pre-traine prediction head๋ ์ ๊ฑฐํ ํ ๋ก ์ด๋ฃจ์ด์ง zero-initialized ํผ๋ํฌ์๋ ๋ ์ด์ด๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์ถ๊ฐ๋ก K๋ downstream class์ ์๋ฅผ ์๋ฏธํฉ๋๋ค.
โ ์ผ๋ฐ์ ์ผ๋ก ์๊ฐํ๋ฉด ์ด๋ฏธ์ง์ resolution์ ๋ฐ๋ผ์ patch ์ฌ์ด์ฆ๋ฅผ ๋ค๋ฅด๊ฒ ํ๋๊ฒ ์๋๋ผ, patch ์ฌ์ด์ฆ๋ฅผ ๊ณ ์ ์ ํฉ๋๋ค. patch์ ์ฌ์ด์ฆ๋ฅผ ๊ณ ์ ํ๋ฉด sequence lengths๋ฅผ ๋ฌ๋ผ์ง๊ฒ ๋ฉ๋๋ค. ๊ทธ๋ ๋ค๋ฉด ์์์ sequence lengths๋ฅผ ์ง์ ํด์ค๋ค๋ฉด ํฌ์ง์ ์๋ฒ ๋ฉ์ ์๋ฏธ๊ฐ ์ฌ๋ผ์ง๊ฒ ๋ฉ๋๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก 2D interpolation๋ฅผ ์ฌ์ฉํด ๋์ฒดํ ์ ์๋ค๊ณ ํฉ๋๋ค.
โ Datasets์ ์๋์ ๋ฐ์ดํฐ๋ค์ ์ฌ์ฉํ์ต๋๋ค.
โ Model variants๋ ์์ ํ์ ๊ฐ์ด Base, Large, Huge๋ก ๊ตฌ๋ถํ๋ฉฐ, ViT-L/16์ด๋ผ๊ณ ํ๋ฉด Large๋ชจ๋ธ์ 16x16 input patch size๋ผ๋ ์๋ฏธ์ ๋๋ค. ๋ํ patch ์ฌ์ด์ฆ๊ฐ ์์ ์๋ก computation cost๋ ์ฆ๊ฐํ ๊ฒ์ ๋๋ค.
โ Training & Fine-tuning ๋ชจ๋ ๋ชจ๋ธ์ training์ Adam optimization(์ ์ฌ์ฉํ์ต๋๋ค. 4096๊ฐ์ batch size๋ฅผ ๊ฐ์ง๋ฉฐ, high weight decay๋ 0.1๋ก ์์ค์ ํฉ๋๋ค. Fine-tuning์์์๋ SGD with momentum์ ์ฌ์ฉํ๋ฉฐ, 512์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ์ฌ์ฉํฉ๋๋ค.
โ ๊ธฐ์กด์ SOTA ๋ชจ๋ธ๊ณผ์ ๋น๊ต์ ๋๋ค. ViT๊ฐ SOTA์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ฉด์, ํจ์ ๋ ์ ์ computational resources๊ฐ ํ์ํ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
โ ์์ ๊ทธ๋ํ์์ ๋ฐ์ดํฐ ํฌ๊ธฐ์ ๋ฐ๋ฅธ ์ฑ๋ฅ์ ํ์ธํ ์ ์์ต๋๋ค. ๋ฐ์ดํฐ์ ์ ํฌ๊ธฐ๊ฐ ์์๋๋ ๊ธฐ์กด์ SOTA ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ๋ ์ข์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค๋ง, ๋ฐ์ดํฐ์ ์ ํฌ๊ธฐ๊ฐ ์ปค์ง๋ฉด ViT์ ์ฑ๋ฅ์ด ๋ ์ข์์ ํ์ธํ ์ ์์ต๋๋ค.
โ ์ด๋ ์์ ๋ง์๋๋ ธ๋ inductive bias์ ์ฐ๊ด์ง์ด ์๊ฐํ๋ค๋ฉด, ๋ฐ์ดํฐ์ ์ ํฌ๊ธฐ๊ฐ ๋ง์ผ๋ฉด inductive bias๊ฐ ํฌ๊ฒ ์ค์ํ์ง ์๋๋ค๋ ๊ฒ์ ์๋ฏธํ ์๋ ์๋ค๊ณ ์๊ฐํฉ๋๋ค.
โ Figure 7์ Left๋ ํ์ต๋ ์๋ฒ ๋ฉ ํํฐ์ ๊ตฌ์ฑ์์๋ฅผ ๋ณด์ฌ์ค๋๋ค. ๊ตฌ์ฑ ์์๋ ๊ฐ patch๋ด์์ ๋ฏธ์ธํ ๊ตฌ์กฐ๋ฅผ ์ ์ฐจ์์ ์ผ๋ก ํํํ๊ธฐ ์ํ ๊ทธ๋ด๋ฏํ ๊ธฐ๋ณธ ๊ธฐ๋ฅ๊ณผ ์ ์ฌํ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค.
โ Figure 7์ center๋ ๋ชจ๋ธ์ position ์๋ฒ ๋ฉ์ ์ ์ฌ์ฑ์ผ๋ก ์ด๋ฏธ์ง ๋ด์ ๊ฑฐ๋ฆฌ๋ฅผ ๋ํ๋ด๊ณ ์์ต๋๋ค. ๋ ๊ฐ๊น์ด patch๋ ๋ ์ ์ฌํ position ์๋ฒ ๋ฉ์ ๊ฐ๋ ๊ฒฝํฅ์ด ์์ต๋๋ค.
โ Figure 7์ right๋ "attention weights"๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ ๋ณด๊ฐ ํตํฉ๋ ์ด๋ฏธ์ง ๊ณต๊ฐ์ ํ๊ท ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ ๊ฑฐ๋ฆฌ๋ฅผ ๋ณด์ฌ์ค๋๋ค. ์ฌ๊ธฐ์ ์๋ฏธํ๋ "attention weights"๋ CNN์ receptive field size์ ์ ์ฌํ ๊ฒ์ ์ ์ ์์ต๋๋ค.
โ ์ง๊ธ๊น์ง ์ ๋ฐ์ ์ธ ViT ๋คํธ์ํฌ์ ๋ํด ์์๋ดค์ต๋๋ค. ViT ๋ชจ๋ธ ๊ตฌ์กฐ๋ ์ฌ์ค ์ transformer์ bert๋ฅผ ์๊ณ ์๋ค๋ฉด, ์ดํดํ๊ธฐ ์ด๋ ต์ง ์์ ๊ตฌ์กฐ๋ผ๊ณ ์๊ฐ๋ฉ๋๋ค. ๋ ผ๋ฌธ์์๋ ์ญ์ transformer ๊ฐ๋ฅํ ํ ๋น์ทํ๊ฒ ๊ตฌ์ฑํ๋ค๊ณ ์ธ๊ธํ๊ณ ์์ต๋๋ค. ์์ ์์ฑ๋ ์คํ ๊ฒฐ๊ณผ์ธ์๋ ๋ค์ํ ์คํ ๊ฒฐ๊ณผ๋ค์ด ๋ ผ๋ฌธ์์ ์ ์๋์ด ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ ๋ ผ๋ฌธ์์ ํ์ธํด์ฃผ์๋ฉด ๋ฉ๋๋ค.