์๋
ํ์ธ์ :) ์ค๋์ ์ง๋๋ฒ ํฌ์คํ
์ ์ด์ด์ "Transferring Inductive Bias Through Knowledge Distillation" ๋
ผ๋ฌธ์ ๋ํ ์ ๋ฆฌ๋ฅผ ์ด์ด๋๊ฐ ๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค. ์ด์ ํฌ์คํ
์์ ๋ณธ ๋
ผ๋ฌธ์์ ๋ค๋ฃจ๊ฒ ๋ ์ฃผ์ ๊ฐ๋
๋ค์ธ Knowledge Distillation
๊ณผ Inductive Bias
์ ๋ํ ์ค๋ช
๊ณผ RNNs vs Transformers
์ ๋ํ ์คํ์ ์งํํ Scenario 1์ ๋ํด ์ด์ผ๊ธฐ๋ฅผ ํ์ด๋ดค๋๋ฐ์.
์ด์ ํฌ์คํธ๊ฐ ๊ถ๊ธํ์ ๋ถ์ ์๋ ๋งํฌ๋ค์ ํตํด ํ์ธํ์ค ์ ์์ต๋๋ค.
1. ๋
ผ๋ฌธ์ ํ์ํ ๊ฐ๋
: Knowledge Distillation & Inductive Bias (๋งํฌ)
2. ๋
ผ๋ฌธ ์๋๋ฆฌ์ค 1 : RNNs vs Transformers (๋งํฌ)
๋ณธ ๋ ผ๋ฌธ์ "Knowledge Distillation์์ Teacher Model์ด Student Model์ ์ ํ๋ Dark Knowledge์ ๊ณผ์ฐ Inductive Bias์ ๋ํ ์ ๋ณด๊ฐ ์กด์ฌํ ๊น?" ๋ผ๋ ์ง๋ฌธ์์ ๋น๋กฏ๋ ์๋ฌธ์ ์ ํ์ธํ๊ธฐ ์ํด ๋๊ฐ์ง ์๋๋ฆฌ์ค๋ฅผ ๊ฐ์ง๊ณ ์คํ์ ์ ๊ฐํฉ๋๋ค. ์ฒซ ๋ฒ์งธ ์๋๋ฆฌ์ค๋ RNNs(Teacher Model)๊ณผ Transformers(Student Model)๋ฅผ, ๊ทธ๋ฆฌ๊ณ ๋ ๋ฒ์งธ ์๋๋ฆฌ์ค๋ CNNs(Teacher Model)๊ณผ MLPs(Student Model)๋ฅผ ๋น๊ตํฉ๋๋ค.
๋ณธ ์ฐ๊ตฌ๋ (1) ์ ๋ง ์ ์ ๋ชจ๋ธ๋ค์ด ๊ฐ์ง๊ณ ์๋ Inductive Bias๊ฐ ์ผ๋ง๋ ์ ์๋ฏธํ๊ฐ๋ฅผ ๋ณด์ฌ์ฃผ๊ฐ, (2) ์ ์ ๋ชจ๋ธ์๊ฒ ์ง์์ ์ ์ ๋ฐ์ ํ์ ๋ชจ๋ธ์ด ์ ๋ง ์ ์ ๋ชจ๋ธ๊ณผ ์ ์ฌํ ํ์ต์ ๊ฒฐ๊ณผ๋ฌผ์ ๋ณด์ฌ์ฃผ๋ ๊ฐ ๋ฅผ ๋ณด์ฌ์ฃผ๋ ๊ฒ์ ๋ชฉ์ ์ผ๋ก ์คํ์ ์งํํ์์ต๋๋ค.
์ด๋ฒ ํฌ์คํ ์์๋ ๋๋ฒ์งธ ์๋๋ฆฌ์ค(CNNs vs MLPs)์ ๋ํด ๋ค๋ค๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
์ปดํจํฐ ๋น์ ์ ๋ํด ๊ด์ฌ์๋ ๋ถ๋ค์ ์ ๋งํ๋ฉด ๋ค์ด๋ดค์ ์ฉ์ด๊ฐ ๋ฐ๋ก CNN์ผ ํ
๋ฐ์. CNN์ Convolutional Neural Network์ ์ค์๋ง๋ก, ํ๊ธ๋ก๋ ํฉ์ฑ๊ณฑ์ ๊ฒฝ๋ง์ด๋ผ๊ณ ๋ ํฉ๋๋ค. ๋ชจ๋ธ์ ์ด๋ฏธ์ง๊ฐ ๋ค์ด์ค๊ฒ ๋๋ฉด Convolution Layer
๊ณผ Pooling Layer
๋ฅผ ํตํด ์ด๋ฏธ์ง์ ํน์ง๋ค(Features)์ ์ถ์ถํ๊ณ , ์ถ์ถ๋ ํน์ง๋ค์ Fully Connected Layer
์ ํต๊ณผ์์ผ ์ฃผ์ด์ง task๋ฅผ ์ํํ๊ฒ ๋ฉ๋๋ค. ๊ฐ๋จํ๊ฒ๋ง ์ดํด๋ณผ๊น์?
Convolution Layer์ ์๋ ๊ทธ๋ฆผ์ฒ๋ผ Window(Kernel)๊ฐ ์ด๋ฏธ์ง๋ฅผ ์ด๋ํ๋ฉด์ ๊ฐ๊ฐ์ ๊ฒน์ณ์ง๋ ํฝ์ ๊ณผ์ ๊ณฑ์ ๋ํ ๋ํ๋ ์ฐ์ฐ(Convolution, ํฉ์ฑ๊ณฑ)์ ์ํํ๋ Layer์ ๋๋ค. ์ด๋ Window์ ๊ฐ๋ค์ ๋ชจ๋ธ์ด ํ์ตํ๊ฒ ๋๋ฉฐ, Window์ ์ญํ ์ ๋ฐ์ดํฐ(์ด๋ฏธ์ง)์ ํน์ง์ ๋งต ํํ์ธ Feature Map(๋๋ Activation Map)์ผ๋ก ์ถ๋ ฅํด์ฃผ๋ ์ญํ ์ ๋๋ค.
Pooling Layer๋ ์์ Convolution Layer์ ์ถ๋ ฅ ๊ฐ(Feature Map)๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ Feature Map์ ํฌ๊ธฐ๋ฅผ ์ค์ด๊ฑฐ๋ ํน์ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์กฐํ๋ ์ฉ๋๋ก ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค. ์ด๋ ์ํํ๊ฒ ๋๋ ์ฐ์ฐ์ Pooling(ํ๋ง) ์ฐ์ฐ์ด๋ผ๊ณ ํ๋๋ฐ, ์ด๋ ์ ์ฌ๊ฐ ํ๋ ฌ(Filter)์ ํน์ ์์ญ ์์ ๊ฐ์ ๋ํฏ๊ฐ์ ๊ตฌํ๋ ๋ฐฉ์์ผ๋ก ๋์ํฉ๋๋ค. Pooling์๋ Max Pooling, Average Pooling, Min Pooling์ด ์กด์ฌํฉ๋๋ค. ์ด๋ฆ์์ ์ ์ ์๋ค์ํผ Max Pooling์ ๊ฐ์ฅ ํฐ ๊ฐ์ด, Average Pooling์ ํ๊ท ๊ฐ์ด, Min Pooling์ ๊ฐ์ฅ ์์ ๊ฐ์ด ์ด์๋จ๋๋ก ํ๋ ๊ฒ์ ๋๋ค.(์ฌ์ง์ฐธ๊ณ )
Multi Layer Perceptron์ ๋ํด ๋ ผํ๊ธฐ ์ ์ ๋จผ์ Perceptron์ ๋ํด ๋ค๋ฃฐ ์ ๋ฐ์ ์๊ฒ ์ฃ ? ํผ์ ํธ๋ก (Perceptron)์ Frank Rosenblatt๊ฐ 1957๋ ์ ์ ์ํ ์ด๊ธฐ ํํ์ ์ธ๊ณต ์ ๊ฒฝ๋ง์ผ๋ก ๋ค์์ ์ ๋ ฅ์ผ๋ก๋ถํฐ ํ๋์ ๊ฒฐ๊ณผ๋ฅผ ๋ด๋ณด๋ด๋ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ์ด๋ ์ ๋ ฅ๊ฐ์ ์ ํ๊ฒฐํฉ ๊ฐ์ ๊ตฌํ๊ณ , ๊ทธ ๊ฐ์ด 0(threshold)๋ณด๋ค ํฐ์ง๋ฅผ ์ฌ๋ถ๋ก ๋ถ๋ฅํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
๋จ์ธต ํผ์
ํธ๋ก ์ OR/AND/XOR ์ค OR๊ณผ AND๋ฌธ์ ๋ฅผ ํ ์ ์์์ง๋ง XOR๋ฌธ์ ๋ ํ ์๊ฐ ์์์ต๋๋ค.
๋จ์ธต ํผ์ ํธ๋ก ์ผ๋ก๋ ํด๊ฒฐํ ์ ์์ ์ด์ ํด๊ฒฐ์ฑ ์ผ๋ก ์ ์๋ ๋ฐฉ๋ฒ์ด ๋๊ฐ์ ํผ์ ํธ๋ก ์ ๊ฒฐํฉํ ์ค ํผ์ ํธ๋ก (2-layer Perceptron)์ธ๋ฐ, ์ด๋ ๊ฒ ์ธต์ด ์ฌ๋ฌ๊ฐ์ธ ํผ์ ํธ๋ก ์ ๋ค์ธต ํผ์ ํธ๋ก (Multilayer Perceptron)์ด๋ผ๊ณ ์นญํฉ๋๋ค. ์ด๋ฌํ ๋ค์ธต ํผ์ ํธ๋ก ์ด ์ฐ๋ฆฌ๊ฐ ์๊ณ ์๋ ์ธ๊ณต์ ๊ฒฝ๋ง(ANN, Artificial Neural Network)๊ฐ ๋๊ฒ๋ฉ๋๋ค.
Source : https://blog.goodaudience.com/artificial-neural-networks-explained-436fcf36e75
์ธ๊ณต ์ ๊ฒฝ๋ง(๋๋ ๋ค์ธต ํผ์ํธ๋ก )์ ์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ์
๋ ฅ์ธต, ์๋์ธต, ์ถ๋ ฅ์ธต์ผ๋ก ๊ตฌ์ฑ์ด ๋ฉ๋๋ค. ์
๋ ฅ์ธต
์ ์
๋ ฅ๋ณ์์ ๊ฐ์ด ๋ค์ด์ค๋ ์ธต, ์๋์ธต
์ ๋ค์ ๋
ธ๋ ๋๋ ์ธต๋ค์ด ํฌํจ๋ ์ ์์ผ๋ฉฐ ๋ฐ์ดํฐ๋ก๋ถํฐ ์จ๊ฒจ์ง ์๋ฏธ(ํน์ง)์ ํ์ตํ๋ ์ธต, ์ถ๋ ฅ์ธต
์ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๋ ์ธต์
๋๋ค.
๋ด๋ด๋คํธ์ํฌ์๋ ์๊ฐํ์ง ์๊ณ ๋์ด๊ฐ ์ ์๋ ์ค์ํ ์ฉ์ด๋ค์ด ์กด์ฌํ๋ ๋ฐ์. ์ด๋ ์๋์ ๊ฐ์ต๋๋ค.
ํ์ฑํ ํจ์์ ์ข ๋ฅ๋ ๋ค์ํ์ง๋ง ๋ํ์ ์ด๊ณ , ๊ณ ์ ์ ์ธ ๋ช ๊ฐ์ง๋ง ์๊ฐํ์๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
์์ ๊ฐ๋ ๋ค์ ๋ณด๋ฉด์ ๋์น์ฑ์ ๋ถ๋ค๋ ์ด๋ฏธ ๊ณ์๊ฒ ์ง๋ง, CNN์ MLP์ ๋นํด ๊ตฌ์กฐ์ ์ผ๋ก ์ ์ฝ(Inductive Bias)์ ๋ ๋ง์ด ๋๊ฒ ๋ฉ๋๋ค. ์๋ ๊ทธ๋ฆผ์ ๋ณด์๋ฉด CNN์ ๊ฒฝ์ฐ Window๊ฐ ์ง๋๋ค๋๋ฉด์ input์ Fixed๋ Window Weight๋ค์ ๋์ผํ๊ฒ ๊ณฑํด์ฃผ๋ ๋ฐ๋ฉด, MLP์ ๊ฒฝ์ฐ๋ input์ ๋ค๋ฅธ Weight๋ค์ ๊ณฑํด์ฃผ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
CNN์ Inductive Bias๋ก๋ ํฌ๊ฒ ๋๊ฐ์ง๋ฅผ ๋ค ์ ์์ต๋๋ค. ์ด๋ ๋ฐ๋ก Translation Invariance์
Scale Invariance์
๋๋ค. Translation Invariance๋ ๋ฌผ์ฒด๋ฅผ ์ด๋(translate) ์์ผ๋ ์ถ๋ ฅ ๊ฐ์ธ Logit ๊ฐ์ ๋ณํ์ง ์๋๋ค๋ ๊ฒ์ด๊ณ , Scale Invariance๋ ๋ฌผ์ฒด์ ์ค์ผ์ผ(scale)์ ์๋ฌด๋ฆฌ ๋ฐ๊พธ์ด๋ ์ถ๋ ฅ ๊ฐ์ธ Logit ๊ฐ์ ๋ณํ์ง ์๋๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
์ด๋ฌํ Inductive Bias๋ CNN์ ์๋์ ํน์ฑ๋ค๋ก ์ธํด Translation๊ณผ Scaling์ ์ํํด๋ ๊ฐ์ด ๋ณด์กดํ ์ ์๊ฒ๋ฉ๋๋ค. ๋ฐ๋ก Convolution ์ฐ์ฐ, Pooling ์ฐ์ฐ, ๊ทธ๋ฆฌ๊ณ Cross-Channel Pooling ์ฐ์ฐ์ ํตํด ์ด๋ฅผ ๋ณด์กดํ๊ฒ ๋ฉ๋๋ค. ์์ ๋ ์ฐ์ฐ์ ์์ Convolutional Neural Nets (CNNs) ํํธ์์ ์๊ฐ๋๋ ธ์ผ๋, ๊ฐ๋จํ๊ฒ Cross-Channel Pooling ์ฐ์ฐ์, ํ ์ฑ๋ ๋ด์์ Pooling ์ฐ์ฐ์ ์ํํ๋ ๊ธฐ๋ณธ Pooling๊ณผ๋ ๋ค๋ฅด๊ฒ, ์ฌ๋ฌ ๊ฐ์ ์ฑ๋ ์์์ ์ด๋ฃจ์ด์ง๋ฉฐ Channel๊ฐ์ Pooling์ ์ํํ ๊ฒ์ผ๋ก ์ดํดํ์๋ฉด ๋ ๊ฒ ๊ฐ์ต๋๋ค.(์๋ ๊ทธ๋ฆผ ์ฐธ๊ณ )
๋ณธ ์๋๋ฆฌ์ค ์ญ์ ๋ง์ฐฌ๊ฐ์ง๋ก CNN๋ชจ๋ธ์ด MLP๋ชจ๋ธ๋ณด๋ค Translation๊ณผ Scaling์ ๋ ์ข์ Inductive Bias๋ฅผ ๊ฐ์ง๊ณ ์๋๊ฐ, ๊ทธ๋ฆฌ๊ณ ๊ณผ์ฐ CNN๋ชจ๋ธ์ Teacher๋ก, MLP๋ชจ๋ธ์ Student๋ก Knowledge Distillation์ ์ํํ์์ ๋ ์ข์ ์ฑ๋ฅ์ด ๋์ค๋๊ฐ๋ฅผ ์คํ์ ํตํด ๋ณด์ด๊ณ ์ ํ์์ต๋๋ค.
๋ชจ๋ธ Training์ ์ํด MNIST ๋ฐ์ดํฐ ์ ์ ์ฌ์ฉํ์๊ณ , Inference ์ฑ๋ฅ ํ์ธ์ ์ํด ๊ธฐ๋ณธ MNIST ๋ฐ์ดํฐ ๋ฟ๋ง ์๋๋ผ MNIST-C(Corrupted) ์ค Traslated์ Scaled ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์์ต๋๋ค.
๋จผ์ , ๊ฐ๊ฐ CNN๋ชจ๋ธ๊ณผ MLP๋ชจ๋ธ์ ์ฑ๋ฅ์ ํ์ธํด๋ณด์๋๋ฐ์. Original MNIST ๋ฐ์ดํฐ์ ๋ํด์๋ CNN๊ณผ MLP ๋ ๋ค ์ข์ ์ฑ๋ฅ์ ๋ํ๋ด๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ํ์ง๋ง, ๊ทธ ์ธ์ Translated์ Scaled MNIST-C ๋ฐ์ดํฐ์์๋ ์ฑ๋ฅ(Accuracy, Expected Calibration Error) ์ฐจ์ด๊ฐ ํฌ๊ฒ ๋ฒ์ด์ง ๋ฟ๋ง ์๋๋ผ, ๊ฐ์ ๋ถ์ฐ๋๊ฐ ํฐ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ์ด๋ฅผ ํตํด CNN์ด ๊ฐ์ง ๊ตฌ์กฐ์ ์ธ ํธํฅ(Inductive Bias)๋ฅผ ํตํด Translation๊ณผ Scaling์ MLP๋ณด๋ค ๊ฐํ๋ค๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
์ด์ Inductive Bias๊ฐ ๋ ํฐ CNN๋ชจ๋ธ์ Teacher๋ชจ๋ธ๋ก ํ๊ณ , ์์ MLP๋ชจ๋ธ์ Student๋ชจ๋ธ๋ก Knowledge Distillation(KD)์ ์ํํ์ ๋ ์๋์ ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ ๋์๋๋ฐ์. ์ด๋ฅผ ๋ณด๋ฉด ์์ ์ฑ๋ฅํ์ ๋นํด์ MLP๋ชจ๋ธ๋ค์ ์ฑ๋ฅ์ด ๋ง์ด ์์นํ ๊ฒ์ ํ์ธํ ์ ์๊ณ , ๋ถ์ฐ๋ ๊ธฐ์กด MLP๋ณด๋ค CNN-MLP๊ฐ ์์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
์์์ ์ ๋์ ์ผ๋ก Inductive Bias๊ฐ ์ ๋ง ์ ๋ฌ๋์ด ๋ชจ๋ธ ์ฑ๋ฅ์ด ํฅ์๋์๋๊ฐ๋ฅผ ํ์ธํ๋ค๋ฉด, ์ด๋ฒ์๋ Multi-dimensional Scaling(MDS)๋ฅผ ์ด์ฉํ์ฌ ๊ณ ์ฐจ์์ Feature Map์ ์๊ฐํํด ๋ณด์์ต๋๋ค. ์ฌ๊ธฐ์๋ ๋ง์ฐฌ๊ฐ์ง๋ก MLP์ ๋ถ์ฐ์ด CNN๋ณด๋ค ํฌ๊ณ , CNN์ Teacher๋ก ํ์ต์ ์ํํ๋ฉด ๋ถ์ฐ์ด ๊ฐ์ํ๋ฉฐ ์๊ฐํ๋๋ ์์น๊ฐ ์ ์ CNN์ ๊ฐ๊น์์ง๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๋ ํ๊ฐ์ง ํฅ๋ฏธ๋ก์ด ์๊ฐํ๋ฅผ ์ํํ์๋๋ฐ์. ์๋ ๊ทธ๋ฆผ์ ๋ณด์๋ฉด ๊ฐ๊ฐ (a),(b),(c)๋ MLP, CNN, CNN->MLP๊ฐ epoch๋ณ๋ก ํ์ต๋์ด ๊ฐ๋ ๊ฒ์ MDS๋ก ์๊ฐํ ํ ๊ฒ์ ๋๋ค. (a)์ MLP๋ฅผ ๋ณด๋ฉด ์ค๊ตฌ๋๋ฐฉ์ ์ผ๋ก ํด๊ฐ ์๋ ดํ๋ ๊ฒ์ ๋ณผ ์ ์๊ณ , (b)์ CNN์ ๋ณด๋ฉด ๋ญ๊ฐ ์์ํ ํน์ ํ ๋ฐฉํฅ์ผ๋ก ์๋ ดํด๊ฐ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. (c)๋ CNN(teacher)->MLP(student)์ธ๋ฐ ๊ธฐ์กด MLP(a)์ ํ์ต ์์๊ณผ๋ ๋ค๋ฅด๊ฒ ๊ท์น์ ์ผ๋ก ํด๊ฐ ์๋ ดํด๊ฐ๋ ๊ฒ์ ํ์ธํ ์ ์์์ต๋๋ค.
๋ณธ ๋ ผ๋ฌธ์ Knowledge Distillation(KD)์ ๊ฒฝ๋ํ ํจ๊ณผ ์ด์ธ์ ๋ค๋ฅธ ๋ชจ๋ธ๋ค๊ณผ ํจ๊ป ์ฐ์ผ ์ ์๋ค๋ ์ ์ ์ฐฉ์ํ์ฌ inductive bias๋ฅผ ๊ณผ์ฐ KD๋ฅผ ํตํด ์ ๋ฌํ ์ ์๋ ๊ฐ๋ฅผ ์คํ์ ํตํด ์๋ ์์๋๋ก ์ ์ฆํด๋ณด์ด๊ณ ์ ํ์์ต๋๋ค.
์ฒซ์งธ, ํน์ task์ ์ ๋นํ inductive bias๋ฅผ ๊ฐ๋ ๊ฒ์ด ์ ๋ง ์ค์ํ ๊ฐ๋ฅผ ์คํ์ ํตํด ์ ์ฆํ์์ต๋๋ค.
๋์งธ, ํด๋น ๋ชจ๋ธ์ด ์ ๋นํ inductive bias๋ฅผ ๊ฐ๊ณ ์๋ค๋ฉด, inductive bias๊ฐ ๋ถ์กฑํ ๋ค๋ฅธ ๋ชจ๋ธ๋ค์๊ฒ ํ์ต์ ๊ฐ์ด๋๋ผ์ธ์ ์ ๊ณตํด์ค ์ ์์์ ์คํ์ ํตํด ์ ๋์ , ๊ทธ๋ฆฌ๊ณ ์ ์ฑ์ ์ผ๋ก ์ ์ฆํ์์ต๋๋ค.
ํฅ๋ฏธ๋กญ๊ฒ ์ฝ์๋ ๋ ผ๋ฌธ์ ๋ฌด๋ ค 3์ฐจ๋ก์ ๊ฑธ์ณ์ ์์ธํ๊ฒ ํ๋ฒ ๋ค๋ฃจ์ด๋ณด์๋๋ฐ์. ์ด๋ ๊ฒ ์์ธํ๊ฒ ๋ ผ๋ฌธ์ ๋ฆฌ๋ทฐํ๋ ๋ฐฉ์์ ์ด๋ ๋์ง ๊ตฌ๋ ์๋ค์ ์๊ฒฌ ๋ํ ๊ถ๊ธํ๋ค์ ๐
๊ธด ๊ธ ์ฝ์ด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค ^~^