์๋
ํ์ธ์ :) ์ค๋์ ์ง๋๋ฒ ํฌ์คํ
์ ์ด์ด์ "Transferring Inductive Bias Through Knowledge Distillation" ๋
ผ๋ฌธ์ ๋ํ ์ ๋ฆฌ๋ฅผ ์ด์ด๋๊ฐ ๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค. ์ด์ ํฌ์คํ
์์ ๋ณธ ๋
ผ๋ฌธ์์ ๋ค๋ฃจ๊ฒ ๋ ์ฃผ์ ๊ฐ๋
๋ค์ธ Knowledge Distillation
๊ณผ Inductive Bias
์ ๋ํ ์ค๋ช
์ ํด๋ณด์๋๋ฐ์. ์ด๋ฒ ํฌ์คํ
์์๋ ํด๋น ๊ธฐ๋ฒ๋ค์ ์ ์ฉํ์ฌ ์ ์๊ฐ ์ํํ ์คํ๋ค ์ค ์ฒซ๋ฒ์งธ ์๋๋ฆฌ์ค์ ๋ํด์ ์ด๊ฐธ๊ธฐ๋ฅผ ํ์ด๊ฐ๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
์ด์ ํฌ์คํธ๊ฐ ๊ถ๊ธํ์ ๋ถ์ ์ฌ๊ธฐ๋ฅผ ํตํด ํ์ธํด ๋ณด์ค ์ ์์ต๋๋ค.
๋ณธ ๋ ผ๋ฌธ์ "Knowledge Distillation์์ Teacher Model์ด Student Model์ ์ ํ๋ Dark Knowledge์ ๊ณผ์ฐ Inductive Bias์ ๋ํ ์ ๋ณด๊ฐ ์กด์ฌํ ๊น?" ๋ผ๋ ์ง๋ฌธ์์ ๋น๋กฏ๋ ์๋ฌธ์ ์ ํ์ธํ๊ธฐ ์ํด ๋๊ฐ์ง ์๋๋ฆฌ์ค๋ฅผ ๊ฐ์ง๊ณ ์คํ์ ์ ๊ฐํฉ๋๋ค. ์ฒซ ๋ฒ์งธ ์๋๋ฆฌ์ค๋ RNNs(Teacher Model)๊ณผ Transformers(Student Model)๋ฅผ, ๊ทธ๋ฆฌ๊ณ ๋ ๋ฒ์งธ ์๋๋ฆฌ์ค๋ CNNs(Teacher Model)๊ณผ MLPs(Student Model)๋ฅผ ๋น๊ตํฉ๋๋ค.
๋ณธ ์ฐ๊ตฌ๋ (1) ์ ๋ง ์ ์ ๋ชจ๋ธ๋ค์ด ๊ฐ์ง๊ณ ์๋ Inductive Bias๊ฐ ์ผ๋ง๋ ์ ์๋ฏธํ๊ฐ๋ฅผ ๋ณด์ฌ์ฃผ๊ฐ, (2) ์ ์ ๋ชจ๋ธ์๊ฒ ์ง์์ ์ ์ ๋ฐ์ ํ์ ๋ชจ๋ธ์ด ์ ๋ง ์ ์ ๋ชจ๋ธ๊ณผ ์ ์ฌํ ํ์ต์ ๊ฒฐ๊ณผ๋ฌผ์ ๋ณด์ฌ์ฃผ๋ ๊ฐ ๋ฅผ ๋ณด์ฌ์ฃผ๋ ๊ฒ์ ๋ชฉ์ ์ผ๋ก ์คํ์ ์งํํ์์ต๋๋ค.
์ด๋ฒ ํฌ์คํ ์์๋ ์ฒซ๋ฒ์งธ ์๋๋ฆฌ์ค(RNNs vs Transformers)์ ๋ํด ๋ค๋ค๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
๋จผ์ ์ฒซ๋ฒ์งธ ์๋๋ฆฌ์ค๋ RNN์ค ๋ํ์ ์ธ ๋ชจ๋ธ์ธ LSTM๊ณผ Transformer๋ฅผ ๋น๊ตํฉ๋๋ค. ๋ ๋ชจ๋ธ ๋ชจ๋ Natural Language Processing(์์ฐ์ด ์ฒ๋ฆฌ)์์ ๋ง์ด ์ฌ์ฉ๋๋ ๋ชจ๋ธ์ด๋ฉฐ, Transformer๋ LSTM์ ๋นํด ๋น๊ต์ ์ต์ ์ ๋์จ ๋ ผ๋ฌธ์ผ๋ก ํ์ต ๋ฐ์ดํฐ๊ฐ ์ถฉ๋ถํ ๋ง์ผ๋ฉด ์๋ง์ ๋ชฉํ(task)์ ์์ด์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ ๋ชจ๋ธ์ ๋๋ค.
Picture from "https://jalammar.github.io/illustrated-transformer"
๋ฐ์ดํฐ๊ฐ ์ถฉ๋ถํ ๋ง๋ค๋ฉด ๋น๊ต์ ์ต๊ทผ์ ๋์จ Transformer์ด LSTM๋ณด๋ค ๋ ์ข์ ์์ธก ์ฑ๋ฅ์ ๊ฐ๋ ๊ฒ์ด ์๋ช ํฉ๋๋ค. ํ์ง๋ง, ๋ฐ์ดํฐ๊ฐ ํ์ ์ (limited)์ธ ์ํฉ์์๋ ํน์ task์์ LSTM์ด Transformer๊ฐ ๋ ๊ฐ๋ ฅํ ๋ชจ๋ธ์ด๋ผ๋ ์ฐ๊ตฌ๊ฐ ์กด์ฌํฉ๋๋ค. ๋ ผ๋ฌธ์์ ์๋ก ๋๋ task๊ฐ ๋ฐ๋ก "Subject-verb agreement prediction task"์ธ๋ฐ์. ํด๋น task๋ "Assessing the Ability of LSTMs to Learn Syntax-Sensitive Dependencies (2016)"์ด๋ผ๋ ๋ ผ๋ฌธ์์ ๊ตฌ๋ฌธ(syntax) ์ ๋ณด๋ฅผ ํ์ตํ๋๋ฐ ์ ์ฉํ๋ค๊ณ ์๊ฐ๋์์ต๋๋ค.
์ด๋ฅผ ์ดํดํ๊ธฐ ์ํด์๋ ์๋ฌธ๋ฒ์ ์กฐ๊ธ ์๊ณ ๊ฐ์ผ ํ๋๋ฐ์. ์์ด 3์ธ์นญ ํ์ฌํ ๋์ฌ์ ํํ๋ ๊ตฌ๋ฌธ ์ฃผ์ด์ ๋จธ๋ฆฌ๊ฐ ๋ณต์์ธ์ง ๋จ์์ธ์ง์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ฉฐ, ์ด๋ฌํ ๋์ฌ๋ ์ฃผ์ด ์์ ํญ์ ์์นํด ์์ง ์์๋ ๋ฉ๋๋ค. ์ข์ธก์ ์๋ ๋ฌธ์ฅ๋ค์ ๋ฐ๋ก ์์ ๋ถ์ด์๋ ์ผ์ด์ค๋ค์ด๋ฉฐ, ์ฐ์ธก์ ์๋ ์์๋ ๋จ์ด์ ธ์๋ ์ผ์ด์ค์ ๋๋ค.
์ด๋ฌํ ๋ฌธ๋ฒ ํน์ง์ ์ด์ฉํ์ฌ LSTM์ด ํ์ ๋ ๋ฆฌ์์ค๋ฅผ ๊ฐ์ง๊ณ ํ์ตํ๋ฉด, Transformer(FAN)๋ณด๋ค ์ข๋ค๋ ๊ฒ์ "The Importance of Being Recurrent for Modeling Hierarchical Structure (2018)"์์ ์ฃผ์ฅํฉ๋๋ค. ๋ณธ ๋
ผ๋ฌธ์์ ์ฌ์ฉํ ํ๊ฐ ๊ธฐ์ค์ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด Input ๋ฌธ์ฅ ์ดํ์ ์ฌ (a) ๋จ์ด๊ฐ ์ด๋ค ๊ฒ์ธ๊ฐ๋ฅผ ์์ธก ๋๋ (b) ๋จ์์ธ์ง ๋ณต์์ธ์ง ๋ฅผ ์ผ๋ง๋ ์ ์์ธกํ๋ ๊ฐ๋ฅผ ํ๊ฐ ์งํ๋ก ์ก๊ฒ ๋ฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์๋ ๊ทธ๋ํ์์ distance
๋ ๊ตฌ๋ฌธ ์ฃผ์ด์ 3์ธ์นญ ํ์ฌํ ๋์ฌ์ ๊ฑฐ๋ฆฌ, # of attracters
๋ ๊ตฌ๋ฌธ ์ฃผ์ด(์ ๋ต, ์ฐธ๊ณ ํด์ผํ ๋ช
์ฌ) ์ธ์ 3์ธ์นญ ํ์ฌํ ๋์ฌ๋ฅผ ์ ํนํ๋ ๋ค๋ฅธ ๋ช
์ฌ๋ค์ ๊ฐฏ์๋ฅผ ์๋ฏธํฉ๋๋ค.
From "The Importance of Being Recurrent for Modeling Hierarchical Structure (2018)"
์! ์ด์ ๋ฌธ๋ฒ ๊ณต๋ถ๊ฐ ๋๋ฌ์ผ๋ ๋ณธ๊ฒฉ์ ์ผ๋ก ๋ชจ๋ธ์ ๋ํด ์ด์ผ๊ธฐ ํด๋ณผ๊น์? RNN(Recurrent Neural Network)์ ์ํ์ค(Sequence) ๋ชจ๋ธ์ ๋๋ค. ์ฆ, ๊ทธ ๋ง์ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ์ํ์ค ๋จ์๋ก ์ฒ๋ฆฌ๋ฅผ ํ๋ค๋ ์๋ฏธ์ธ๋ฐ์. ์ฌ๊ธฐ์ ๋น๊ต ๋ชจ๋ธ๋ก ์ฌ์ฉํ๋ LSTM ์ญ์ ์ด๋ฌํ RNN์ ๊ทผ๋ณธ์ผ๋ก ํ๋ ๋ชจ๋ธ์ด๋ฏ๋ก, ๊ฐ๋ณ๊ฒ RNN์ ๋ํ ๊ฐ๋ ๊ณผ RNN์ Inductive Bias(๊ท๋ฉ์ ํธํฅ)์ ๋ํด ์ด์ผ๊ธฐํ๊ณ ๋์ด๊ฐ๋๋ก ํ๊ฒ ์ต๋๋ค.
์๋ ๊ทธ๋ฆผ์ ์ฐธ๊ณ ํ๋ฉด์ ์ค๋ช
๋๋ฆฌ๋๋ก ํ๊ฒ ์ต๋๋ค. ํด๋น ๊ทธ๋ฆผ์์ ๊ณตํต์ ์ธ ๋ถ๋ถ์ ๋ํด ๋จผ์ ์ค๋ช
๋๋ฆฌ์๋ฉด, ๋นจ๊ฐ์ ๋ฐ์ค๋ input, ํ๋์ ๋ฐ์ค๋ output, ์ด๋ก์ ๋ฐ์ค๋ (hidden) state์
๋๋ค.
Picture from CS231n
one-to-one
: Vanila Neural Network
์ฐ๋ฆฌ๊ฐ ํต์์ ์ผ๋ก ์๊ณ ์๋ ๋ด๋ด๋คํธ์ํฌ๋ก, ํ๋์ input์ ํ๋์ output์ด ๋์๋๋ ๊ตฌ์กฐ์
๋๋ค.
one-to-many
: Recurrent Neural Network
ํ๋์ input์ ์ฌ๋ฌ ๊ฐ์ output์ด ๋์๋๋ ๊ตฌ์กฐ๋ก, ๋ํ์ ์ธ ์์๋ก ์ด๋ฏธ์ง๊ฐ ํ๋๊ฐ ๋ค์ด๊ฐ์ ๋ ์ด๋ฅผ ์ค๋ช
ํ๋ ๋ฌธ์ฅ(sequence of words)์ด ๊ฒฐ๊ณผ๋ก ๋์ค๋ Image Captioning Task๊ฐ ์กด์ฌํฉ๋๋ค.
many-to-one
: Recurrent Neural Networkmany-to-many
: Recurrent Neural Network์ด๋ ๋ฏ RNN์ ๋งค Timestep๋ง๋ค ์๋ก์ด input์ด ๋ค์ด์ค๋ฉด, ์ด๋ฅผ fucntion์ ํต๊ณผ์ํค๊ณ state๋ฅผ ์ ๋ฐ์ดํธํ๊ฒ ouput์ ๋ฐํํ๊ฒ ๋ฉ๋๋ค. ์ด๋, function์ ์ ๋ถ ๋์ผํ๋ฉฐ, ๋์ผํ parameter๋ฅผ ์ฌ์ฉํฉ๋๋ค. ๋ํ, ๋ค์ state๋ ์ด์ state ์ ๋ณด๋ง์ ๋ฐ์ ์ ์์ผ๋ฉฐ, ๋ ์ด์ ์ ๋ณด๋ ๋ฏธ๋ ์ ๋ณด๋ ์ ๋ ๋ณผ ์ ์์ต๋๋ค.
์ด๋ฌํ ์์์ ์๊ฐํ RNN์ ํน์ง๋ค์ด ๋ฐ๋ก RNN์ inductive bias์ ๋๋ค. RNN์ inductive bias๋ ๋ค์๊ณผ ๊ฐ์ด ์ด 3๊ฐ์ง๋ก ๋ ผ๋ฌธ์์ ์ด์ผ๊ธฐํ๊ณ ์๋๋ฐ ์ด๋ ๊ฐ๊ฐ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Sequential-lity
: ๋ชจ๋ธ์ ๋ฃ์ด์ฃผ๋ ๋ฐ์ดํฐ๋ค์ด ์์ฐจ์ ์ผ๋ก ๋ค์ด์ค๋๋ก ๊ฐ์ ํ๋ "์์ฐจ์ฑ"
Memory Bottleneck
: ํด๋น timestamp ๋ฐ๋ก ์ด์ ์ hidden state์ ๋ณด๋ง์ ๋ชจ๋ธ์ด ๋ฐ์ ์ ์๊ธฐ ๋๋ฌธ์, ํด๋น hidden state๊ฐ ๋ ์ด์ ๊ณผ๊ฑฐ์ ๋ด์ฉ๊น์ง ์ ๋ถ ํจ์ถ์ ์ผ๋ก ๊ฐ์ถ๋๋ก ๊ฐ์ ํ๋ "๋ฉ๋ชจ๋ฆฌ์ ๋ณ๋ชฉ์ฑ"
Recursion
: ๋ชจ๋ ํจ์๊ฐ ๋์ผํ๋๋ก ๊ฐ์ ํ๋ "์ฌ๊ท์ฑ"
Transformer๋ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ์ธ์ฝ๋์์ ์
๋ ฅ ์ํ์ค(ex. I am a student)๋ฅผ ์
๋ ฅ๋ฐ๊ณ , ๋์ฝ๋์์ ์ถ๋ ฅ ์ํ์ค(ex. Je suis รฉtudiant)๋ฅผ ์ถ๋ ฅํ๋ ์ธ์ฝ๋-๋์ฝ๋ ๊ตฌ์กฐ์ ๋ชจ๋ธ์
๋๋ค. ๋ณธ ๋ชจ๋ธ์ RNN์ฒ๋ผ ์์ฐจ์ ์ผ๋ก ๋จ์ด๊ฐ ๋ค์ด๊ฐ์ง ์์๋ Self-Attention
๊ณผ Feed Forward Neural Network
๋ง์ผ๋ก๋ ์ข์ ์ฑ๋ฅ์ ๋ผ ์ ์๋ค๋ ๊ฒ์ ๋ณด์ฌ์ค ํ๊ธฐ์ ์ธ ๋
ผ๋ฌธ(๋ฐฉ๋ฒ๋ก )์
๋๋ค. ๋ณธ ํฌ์คํธ์์ Transformer์ ๋ํ ๊ฐ๋
์ ๋ฐ์ ๋ค๋ฃจ๊ธฐ์๋ ๋๋ฌด ๊ธธ์ด์ง๊ธฐ ๋๋ฌธ์ RNN์ ๋นํด ์ ์ฝ์ด ์ ๋ค๋ ์ ๋๋ง ์ดํดํ์๊ณ ๋์ด๊ฐ๋ฉด ์ข์ ๊ฒ ๊ฐ์ต๋๋ค.
Picture from "๋ฅ๋ฌ๋์ ์ด์ฉํ ์์ฐ์ด์ฒ๋ฆฌ ์ ๋ฌธ (์ ์์ค)"
Transformer์ ๊ฒฝ์ฐ, RNN์ ๋นํด ์ ์ฝ ๋๋ Inductive Bias๊ฐ ํจ์ฌ ์ฝํ ์ด์ ๋ ์๋์ ๊ฐ์ต๋๋ค.
Picture from "https://jalammar.github.io/illustrated-transformer"
Transformer๋ ํ ํฐ๋ค์ ์์น ์ ๋ณด๋ฅผ ์๋ฐฐ๋ฉํ๊ธฐ ์ํด positional encoding ์ ๋ํด์ฃผ๋ ๋ฐ ์ด๋ ๋จ์ํ sinํจ์์ cosํจ์๋ก ๋์ถ๋ ๋ฒกํฐ๊ฐ์ผ๋ก, ๋ชจ๋ธ ๋จ์์ ๊ฐ์ ์ ๋ก ๋ฐ์ดํฐ๋ฅผ ์์ฐจ์ ์ผ๋ก ๋ฐ๊ฒํ๋ RNN๊ณผ ๊ฐ์ Sequential-lity
๊ฐ ์กด์ฌํ์ง ์์ต๋๋ค.
Transformer๋ ์ ์ฒด ํ ํฐ์ ๋ํ ์ ๋ณด๋ฅผ Self-attention์ ํตํด ์ ๋ฐ์ ์ผ๋ก ๋ฐ์ ์ ์๊ธฐ์, ์ด์ timestamp์ hidden-state๋ง์ ์ ๋ฌ๋ฐ์ ์ ์๋ RNN๊ณผ ๊ฐ์ Memory Bottleneck
์ด ์กด์ฌํ์ง ์์ต๋๋ค.
Transformer๋ Encoder์์ Decoder๋ก ํ๋ฒ์ ๊ฐ๋ ๊ตฌ์กฐ์ด๋ฏ๋ก, ๊ฐ์ ํจ์๊ฐ ์ฐ์์ ์ผ๋ก ์ฌ์ฉ๋๋ RNN๊ณผ ๊ฐ์ recursion
์ด ์กด์ฌํ์ง ์์ต๋๋ค.
Dataset
Performance Metric
Learning Objective
๊ฐ๊ฐ์ Objective์ ๋ฐ๋ฅธ ์คํ ๋ชจ๋ธ๊ตฐ
Language Modelling (LM) Setup :
1. LSTM
: Base LSTM
2. Small LSTM
: LSTM with smaller parameter
3. Transformer
: Base Transformer
4. Small Transformer
: Transformer with smaller parameter
Classification Setup :
1. LSTM
: Base LSTM (Sequentiality + Memory bottleneck + Recursion)
2. Transformer
: Base Transformer (No Inductive Bias)
3. Transformer-seq
: Base Transformer์ Sequentiality๋ฅผ ๊ฐ์ ๋ก ์ถ๊ฐํด์ค ๋ชจ๋ธ (Sequentiality)
4. Universal Transformer-seq
: Transformer-seq์ Recursion์ ๊ฐ์ ๋ก ์ถ๊ฐํด์ค ๋ชจ๋ธ (Sequentiality + Recursion)
๋ณธ ๋ ผ๋ฌธ์ ์คํ์ ํตํด ๋ค์ 2๊ฐ์ง๋ฅผ ์ฆ๋ช ํ๊ณ ์ ํ์์ต๋๋ค:
Without Distillation
โป ์ฌ๊ธฐ์ ์ ๊น!
(์ฐธ๊ณ ) Calibration Error๋, ๋ชจํ์ ์ถ๋ ฅ ๊ฐ์ด ์ค์ confidence๋ฅผ ๋ฐ์ํ๋๋ก ๋ง๋๋ ๊ฒ์ ๋๋ค. ์๋ฅผ ๋ค์ด, X์ Y์ ๋ํ ๋ชจ๋ธ์ ์ถ๋ ฅ์ด 0.8์ด ๋์์ ๋, 80 % ํ๋ฅ ๋ก Y์ผ ๊ฒ์ด๋ผ๋ ์ ๋ขฐ(์๋ฏธ)๋ฅผ ๊ฐ๋๋ก ๋ง๋๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. ์ด๋ฌํ Calibration Error๋ Bining์ด๋ผ๋ ์์ ํตํด M๊ฐ์ Bin์ ๋ํ์ฌ ๊ฐ๊ฐ์ Bin๋ง๋ค์ Calibration Error๋ฅผ ๊ตฌํ์ฌ ํ๊ท (Expectation)์ ๋ด์ ์ฐ์ถํ๋ฏ๋ก, Expectated Calibration Error(ECE)๋ก ํ๊ฐ ์งํ๊ฐ ์ฌ์ฉ๋ฉ๋๋ค.
With Distillation
LSTM์ ์ง์(Knowledge)๋ฅผ Transformer์๊ฒ ์ ๋ฌํด์ค์ผ๋ก์จ ๊ธฐ์กด์ Transformer(๋นจ๊ฐ)์ ๋นํ์ฌ ์ฑ๋ฅ์ด ํฅ์๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.(ํ๋, ํ๋)
๋ํ, ๊ธฐ์กด์ Tranformer์ perplexity๋ teacher model์ ๊ทผ์ฌํ๊ฒ ๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
Language Modelling(LM) Setup
Classification Setup
๋ง์ง๋ง์ผ๋ก Multidimensional Scaling(MDS) ํตํด ๋ชจ๋ธ penultimate layer์์์ representation์ ์๊ฐํ ํด๋ณธ ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ์ต๋๋ค.
๊ธฐ์กด Transformer ๋ชจ๋ธ์ Variance๊ฐ ๋์ ๋ฐ๋ฉด Inductive Bias๊ฐ ๋์ ์์ผ๋ก ๋ชจ๋ธ์ Variance๊ฐ ๋ฎ๊ณ , Distillation์ ์ํํ๋ฉด ๊ธฐ์กด ๋ชจ๋ธ์ด Teacher ๋ชจ๋ธ๊ณผ ์ ์ฌํด์ง๋ฉฐ Variance ๋ํ ๊ฐ์ํ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค.
์ค๋์ ์๋๋ฆฌ์ค 1์ ๋ํ์ฌ ์์ธํ๊ฒ ๋ค๋ฃจ์ด๋ณด์๋๋ฐ์. ๋น ๋ฅธ ์์ผ ๋ด์ ์๋๋ฆฌ์ค 2๊น์ง ์ ๋ก๋ํ๋๋ก ํ๊ฒ ์ต๋๋ค :D
๊ธด ๊ธ ์ฝ์ด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค ใ ใ ใ