ICLR 2017 conference paper๋ก 2023๋
11์ ๊ธฐ์ค 3000ํ ๋๋ ์ธ์ฉ ํ์๋ฅผ ๊ฐ์ง๊ณ ์๋ ์ค์ํ ๋
ผ๋ฌธ์ ์ฝ๊ณ ํด์ํด๋ณด์๋ค.
๋
ผ๋ฌธ ๋ฆฌ๋ทฐ๋ฅผ ์ด๋ป๊ฒ ํ๋ฉด ์ ํ ์ ์๋ ๊ฑด์ง ์์ง๋ ์์ ๋ชจ๋ฅด๊ฒ ์ง๋ง, ์ผ๋จ ์ฝ์ผ๋ฉด์ ์๊ฐ์ ์ ๋ฆฌํด๋ณธ๋ค๋ ์์ผ๋ก ํด๋ณด๋ ค๊ณ ํ๋ค. ๋ชฉํ๋ ๋ถ์บ ํ๋ ์ค์ ์ง์น์ง ์๊ณ ๊พธ์คํ ํ๋ ๊ฒ์ด๋ค. ๊พธ์คํ ํด๋ณด์๋ ๋ชฉํ๋ก ๋ถ์บ ์์ ๋
ผ๋ฌธ ๋ฆฌ๋ทฐ ์คํฐ๋๋ ํ๋ ๋ง๋ค์ด์ ๊ฐ์ด ํ๊ณ ์๋ค.
๋ฅ๋ฌ๋ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก SGD(Stochastic Gradient Descent) ๋ง์ด ์ฌ์ฉ๋จ.
โ ํฐ ๋ฐฐ์น ์ฌ์ด์ฆ ์ฌ์ฉ์, ๋ชจ๋ธ ์ฑ๋ฅ์ ์ ํ ๋ณด์(์ผ๋ฐํ์ ๋ฅ๋ ฅ์ ๊ธฐ์ค์ผ๋ก ๋ณด๋ฉด)
โ ๊ทธ ์ด์ ๋ ํฐ ๋ฐฐ์น ์ฌ์ด์ฆ๋ ํ์ต๊ณผ ํ
์คํธ ๋ชจ๋ธ์์ sharp minimizer๋ก ์๋ ดํ๋๋ฐ, sharp minimizer๋ ์ผ๋ฐํ ์ฑ๋ฅ์ ์ ํ์ํด
โ ๋ฐ๋๋ก ์์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ flat minimizer๋ก ์๋ ดํ๋๋ฐ, ๊ทธ ์ด์ ๋ ๊ธฐ์ธ์ด(gradient)๋ฅผ ์ถ์ ํ๋ ๋ฐ์ ์๋ ๋ด์ฌ๋ ๋
ธ์ด์ฆ ๋๋ฌธ์ผ๋ก ๋ด
์ด ์์์ Batch Size Gradient ๋ฐฉ๋ฒ๋ก ์ ์์์ผ๋ก ์๊ฐํ๋๋ฐ,
์ข๋ ์ ํํ๊ฒ ํด์ํ๋ฉด MGD(Mini Batch Gradient Descent)๋ฅผ ํํํ๋ค.
(Batch Size ๊ฐ์๋ก loss๋ฅผ ๋๋๊ธฐ ๋๋ฌธ์)
๊ทธ๋์, ๋ณธ ๋
ผ๋ฌธ์
1) ์ต์ ์ ํผํฌ๋จผ์ค ๋ชจ๋ธ์ ์ ์ํ๊ณ ,
2) ํฐ ๋ฐฐ์น ์ฌ์ด์ฆ ํ์ต์ ๊ทน๋ณตํ๋ ๋ฐฉ๋ฒ๋ก ์ ์ ์ํ๋ค๊ณ ํจ.
1) LB ๋ฐฉ๋ฒ์ ๋ชจ๋ธ์ ๊ณผ์ ํฉํด์
2) LB ๋ฐฉ๋ฒ์ ์์ฅ์ (saddle point )์ ๋ ๋งค๋ ฅ์ ์
3) LB ๋ฐฉ๋ฒ์ SB ๋ฐฉ๋ฒ์ ํ์์ ํน์ฑ์ด ๋ถ์กฑํ๋ฉฐ ์ด๊ธฐ ์ง์ ์ ๊ฐ์ฅ ๊ฐ๊น์ด ์ต์ํ ์ฅ์น๋ฅผ ํ๋ํ๋ ๊ฒฝํฅ์ด ์์ด์
4) SB ๋ฐ LB ๋ฐฉ๋ฒ์ ์ง์ ์ผ๋ก ๋ค๋ฅธ ์ต์ํ ๋๊ตฌ๋ก ์๋ ด๋๊ธฐ์
์ผ๋ฐํ ๋ฅ๋ ฅ์ ๋ถ์กฑ์ ์์ธ์ LB๋ sharp minimizer, SB๋ flat minimizer์ด๊ธฐ ๋๋ฌธ์ด๋ผ๊ณ ํ๋ค.
์๋์ sharp minimizer์ flat minimizer ๊ทธ๋ํ์์ y์ถ์ loss fuction value๋ฅผ ์๋ฏธํ๊ณ ์์ ์์์์ ๋ฅผ ์๋ฏธํ๋ค. ์ฆ, loss function์ "ํํํ(flat)" ์ต์๊ฐ์ด SB(์์ ๋ฐฐ์น์ฌ์ด์ฆ)๋ก ์ป์ด์ง ๊ฒฐ๊ณผ๋ค.
โ sharp minimizer ์ x์ ์์ ๋ณํ์๋ ํจ์๊ฐ ๊ธ๊ฒฉํ๊ฒ ์ฆ๊ฐํ๊ณ , flat minimizer์ ๋น๊ต์ x์ ์์ ๋ณํ์์ ์ฒ์ฒํ ๋ณํํ๋ค. flat minimizer ๋ฎ์ ์ ๋ฐ๋๋ผ๋ฉด sharp minimizer์ ๋์ ์ ๋ฐ๋๋ฅผ ๋ณด์. ์ฌ๊ธฐ์ sharp minimizer์ ํฐ ๋ฏผ๊ฐ๋(์ ๋ฐ๋)๋ ์๋ก์ด ๋ฐ์ดํฐ์ ๋ํ ๋ชจ๋ธ์ ์ผ๋ฐํ ๋ฅ๋ ฅ์ ์ ์ข์ ์ํฅ์ ๋ผ์นจ
- ๋ชจ๋ ์คํ์์ ๋๊ท๋ชจ ๋ฐฐ์น ์คํ์์๋ ํ๋ จ ๋ฐ์ดํฐ์ 10%๋ฅผ ๋ฐฐ์น ํฌ๊ธฐ๋ก ์ฌ์ฉํ๊ณ , ์๊ท๋ชจ ๋ฐฐ์น ์คํ์์๋ 256๊ฐ์ ๋ฐ์ดํฐ ํฌ์ธํธ๋ฅผ ์ฌ์ฉํจ.
โ ๋ชจ๋ ๋คํธ์ํฌ์์ ๋ ๊ฐ์ง ์ ๊ทผ ๋ฐฉ์ ๋ชจ๋ ๋์ ํ๋ จ ์ ํ๋๋ก ์ด์ด์ก์ง๋ง ์ผ๋ฐํ ์ฑ๋ฅ์๋ ์๋นํ ์ฐจ์ด๊ฐ ์์์ ๊ด์ฐฐํ ์ ์์ต๋๋ค.
(We emphasize that the generalization gap is not due to over-fitting or over-training as commonly observed in statistics)
์์ ๊ฐ์ด Sharpness๋ผ๋ metric์ ์๋กญ๊ฒ ์ ์ํ๋ค.
์๋ฒฝํ ์์์ด ์ดํด๋์ง๋ ์์ง๋ง, ์ ๋นํ boundary ()์์์ maximization of loss function ์ ์ธก์ ํ๋ฉด์ ๋ฏผ๊ฐ๋?๋ฅผ ๋ณธ๋ค.
์์ sharpness ํ๋ฅผ ๋ณด๋ฉด, ์ ๋ฐ๋ผ ๊ตฌํ ๋ชจ๋ธ ๋ณ SB, LB์ ๋ฐ๋ฅธ ๋ฏผ๊ฐ๋๋ฅผ ๋ณผ ์ ์๋ค.
- SB (์์ ๋ฐฐ์น) โ ๋ ๋ฎ์ sharpness๊ฐ
- LB (ํฐ ๋ฐฐ์น) โ ํฐ sharpness๊ฐ
์คํ ๊ฒฐ๊ณผ์์๋ ์์ ์๋ฏ์ด, ํฐ ๋ฐฐ์น ๋ฐฉ๋ฒ(LB)์ผ๋ก ์ป์ ์๋ฃจ์ ์ด ํ๋ จ ํจ์์ ๋ ํฐ ๋ฏผ๊ฐ๋ ์ง์ ์ ์ ์ํ๋ค๋ ๊ด์ ์ ํ์ธํ ์ ์๋ค.
๐ค ๊ทธ๋ ๋ค๋ฉด, LB method์์๋ generalization gap์ ์ด๋ป๊ฒ ๊ทน๋ณตํด์ผํ๋?
: data augumentation(๋ฐ์ดํฐ ์ฆ๋), conservative training(๋ณด์์ ํ๋ จ), adversarial(์ ๋์ ํ๋ จ) ๊ฐ์ ์ ๊ทผ๋ค์ด generalization gap์ ๊ทน๋ณตํ๋๋ฐ ๋์์ด ๋์ง๋ง, ํ์ง๋ง ์ฌ์ ํ ์๋์ ์ผ๋ก sharp minimizer๋ก ์ด์ด์ง๋ฉฐ ๋ฌธ์ ๋ฅผ ์์ ํ ํด๊ฒฐํ์ง๋ ๋ชปํ๋ค.
์๋์ Sharpness์ Accuracy ๊ฒฐ๊ณผ ๊ฐ์ batch size์ ๋ฐ๋ผ ํ์ธํด ๋ณด๋ฉด, small batch ๋ชจ๋ธ์ด ๋ ์ฑ๊ณต์ ์ด์์์ ํ์ธํ ์ ์๋ค.
"์์คํํธ(Warm-Starting) ์คํ"์ ์๋ํด๋ณด๊ธฐ๋ ํ์๋๋ฐ,
๋ณธ ๋
ผ๋ฌธ์์๋ piggybacked(or warm-started) large-batch solution์ด๋ผ๊ณ ๋ถ๋ฅธ๋ค.
Step.1 ) batch size 256์ผ๋ก ํ์ฌ(์์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ก) ADAM๋ฅผ ์ฌ์ฉํ์ฌ 100 epoch๋ก ํ๋ จ
Step.2 ) ๊ฐ epoch ์ดํ์ iterate(๋ฐ๋ณต)์ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ง => epoch๋ง๋ค ํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ฅํด๋
Step.3 ) ์ด๋ฌํ 100ํ์ iterate(๋ฐ๋ณต) ์งํ๋ฅผ ๊ฐ๊ฐ ์์์ ์ผ๋ก ํ์ฌ LB(ํฐ ๋ฐฐ์น ์ฌ์ด์ฆ) ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ 100 iteration ๋์ ํ๋ จ
Step.4) 100๊ฐ์ ํผ๊ธฐ๋ฐฑ(๋๋ ์ ์คํํธ) ๋๊ท๋ชจ ๋ฐฐ์น ๋ชจ๋ธ๋ง ๊ฒฐ๊ณผ ํ์ธ
์๋ ๊ทธ๋ฆผ 5์๋ ์๊ท๋ชจ ๋ฐฐ์น ๋ฐ๋ณต์ ํ
์คํธ ์ ํ๋์ ํจ๊ป ์ด๋ฌํ ๋๊ท๋ชจ ๋ฐฐ์น ์๋ฃจ์
์ ํ
์คํธ ์ ํ๋ ๋ฐ sharpness(์ ๋ช
๋)๊ฐ ํ์๋์ด ์๋ค. ๋ช ๋ฒ์ ์ด๊ธฐ epoch ๋ง์ผ๋ก ์ ์คํํธํ๋ฉด LB ๋ฐฉ๋ฒ์ด ์ผ๋ฐํ ๊ฐ์ ์ ๊ฐ์ ธ์ค์ง ์๋๋ค. ์ฌ์ง์ด sharpness(์ ๋ช
๋)๋ ๋๊ฒ ์ ์ง๋๋ค.
๋ฐ๋ฉด, ํน์ ํ์์ ์ ์คํํธ ์ดํ์๋ ์ ํ๋๊ฐ ํฅ์๋๊ณ ๋๊ท๋ชจ ๋ฐฐ์น ๋ฐ๋ณต์ ์ ๋ช
๋๊ฐ ๋จ์ด์ง๋๋ค. ์ด๋ ๋ถ๋ช
ํ SB ๋ฐฉ๋ฒ์ด ํ์ ๋จ๊ณ๋ฅผ ์ข
๋ฃํ๊ณ flat minimizer๋ฅผ ๋ฐ๊ฒฌํ์ ๋ ๋ฐ์ํ๋ค. ๊ทธ๋ฐ ๋ค์ LB ๋ฐฉ๋ฒ์ด ์ด๋ฅผ ํฅํด ์๋ ดํ ์ ์์ด ํ
์คํธ ์ ํ๋๊ฐ ํฅ์๋ฉ๋๋ค.
piggybacked(or warm-started) large-batch solution์ ๋ ๋์ ๋ฐฉ๋ฒ์ผ ์ ์์!
(์ผ๋ถ ์๋ ด ํ์ ์ ์ญ ์ต์๊ฐ์ ์ฐพ๊ธฐ ๋๋ฌธ)
์์ค ํจ์์ ๊ฐ์ด ๋ ํฐ ๊ฒฝ์ฐ, ์ฆ ์ด๊ธฐ์ ๊ทผ์ฒ์์๋ SB ๋ฐ LB ๋ฐฉ๋ฒ์ด ๋น์ทํ ์ ๋ช ๋ ๊ฐ์ ์ฐ์ถํ๋ค.
ํ์ง๋ง, ์์ค ํจ์๊ฐ ๊ฐ์ํจ์ ๋ฐ๋ผ LB ๋ฐฉ๋ฒ์ sharpness๊ฐ ๊ธ๊ฒฉํ ์ฆ๊ฐํ๋ ๋ฐ๋ฉด, SB ๋ฐฉ๋ฒ์ ๊ฒฝ์ฐ sharpness๋ ์ฒ์์๋ ์๋์ ์ผ๋ก ์ผ์ ํ๊ฒ ์ ์ง๋ ๋ค์ ๊ฐ์ํ์ฌ ํ์ ๋จ๊ณ์ ์ด์ด flat minimizer๋ก ์๋ ดํจ์ ๋ณด์ธ๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก,
โ
sharp minimizer ๋ก์ ์๋ ด์ ํฐ ๋ฐฐ์น ๋ชจ๋ธ(LB)์ ์ผ๋ฐํ๊ฐ ์ ๋๋ก ์ด๋ฃจ์ด์ง์ง ์๊ฒ ๋ง๋ ๋ค.
โ
LB์ ์ผ๋ฐํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ ์ data augmentation(๋ฐ์ดํฐ ์ฆ๋), conservative training(๋ณด์์ ํ๋ จ) and robust optimization(๊ฐ๋ ฅํ ์ต์ ํ)๊ฐ ๋ฐฉ๋ฒ์ผ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ์ง๋ง, ์ด๋ฌํ ์ ๋ต์ ๋ฌธ์ ๋ฅผ ์๋ฒฝํ ํด๊ฒฐํด์ฃผ์ง ๋ชปํ๋ค. ๋ฌผ๋ก LB๋ชจ๋ธ์ ์ผ๋ฐํ ์ฑ๋ฅ์ ํฅ์์ํค๊ธฐ๋ ํ์ง๋ง ์ฌ์ ํ sharp minimizer๋ก ์ด๋๋ค.
โ
LB์ ์ผ๋ฐํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ ์ dynamic sampling(batch size๋ฅผ ์ ์ง์ ์ผ๋ก ํค์ฐ๋ ๊ฒ)๋ ๊ณ ๋ คํด๋ณด์๋ค. warm-starting experiment๋ฅผ ์ค์ ๋ก ์ํํด๋ณด์๊ณ , ์ด๋์ ๋ LB ๋ชจ๋ธ์์ ๊ด์ฐฎ์ ๋ฐฉ์์ด๋ผ๊ณ ๋ด.
โ
๊ธฐ์กด ์ฐ๊ตฌ ๊ฒฐ๊ณผ์์ ๊ฐ์ ํ์, ๋ฅ ๋ฌ๋ ๋ชจ๋ธ์ ์์ค ํจ์์ ๋ง์ local minimizer๊ฐ ํฌํจ๋์ด ์์ผ๋ฉฐ ์ด๋ฌํ minimizer ์ค ๋ค์๊ฐ ์ ์ฌํ ์์ค ํจ์ ๊ฐ์ ํด๋นํ๋ค๋ ๊ฒ์ ๋ณด์ฌ์ค๋ค. ๋ณธ ์คํ์์ sharp minimizer๊ณผ flat minimizer ๋ชจ๋ ๋งค์ฐ ์ ์ฌํ ์์ค ํจ์ ๊ฐ์ ๊ฐ๊ธฐ ๋๋ฌธ์ ๋ณธ ์คํ ๊ฒฐ๊ณผ๋ ์์ ์ฐ๊ตฌ ๊ฒฐ๊ณผ์ ๊ด์ฐฐ๊ณผ ์ผ์นํ๋ค.
โ
์ฆ, Small Batch Size๊ฐ Generalizatioin Performance์์ ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ธ๋ค
์ด๋ ๊ฒ ์ ๋ฆฌ ๋ ๐ฅน