Contributions

- ์ฑ๋๋ณ ๋ณด๊ฐ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ์ฌ Source batch statistics(CBN)์ Test batch statistics(TBN)๋ฅผ ๊ฒฐํฉํ๋ฉฐ, ์๋ก์ด ๋๋ฉ์ธ์ ๋ํด ์ ์ํ๋ฉด์๋ ์์ฒ์ง์์ ๋ณด์กดํ๋ TTN ๋ ์ด์ด๋ฅผ ์ ์ํ๋ค.
- ๊ธฐ์กด์ TTA ๋ฐฉ๋ฒ์ TTN์ ์ถ๊ฐํ๋ฉด ํ
์คํธ ๋ฐฐ์น ํฌ๊ธฐ(1์์ 200๊น์ง)์ ํญ๋์ ๋ฒ์์์ ์ฑ๋ฅ์ด ํฌ๊ฒ ํฅ์๋จ์ ๋ณด์ธ๋ค.
2. Methodology
2.1. Problem Setup
- ํ๋ จ ๋ฐ์ดํฐ์ ํ
์คํธ ๋ฐ์ดํฐ๋ฅผ DSโ,DTโ, ๊ฐ๊ฐ์ ๋ํ ํ๋ฅ ๋ถํฌ๋ฅผ PSโ,PTโ๋ผ ํ์.
- TTA์ Covariate Shift๋ PSโ๎ โ=PTโ where PSโ(yโฃx)=PTโ(yโฃx)์ด๋ค.
- ๋ชจ๋ธ fฮธโ๋ DSโ์ mini-batch BS={(xiโ,yiโ)}i=1โฃBSโฃโ๋ก ํ์ต๋์์ผ๋ฉฐ,
- ํ
์คํธ ์ค์๋ fฮธโ๋ ํ
์คํธ ๋ฐฐ์น BTโโผDTโ๋ฅผ ๋ง๋๊ฒ ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ TTA์ ๋ชฉํ๋ ๋ค๋ฅธ ๋ถํฌ๋ก๋ถํฐ ํ
์คํธ ๋ฐฐ์น๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ์ฒ๋ฆฌํ๋ ๊ฒ์ด๋ค.
์ข ๋ ์ค์ฉ์ ์ธ TTA๋ฅผ ์๋ฎฌ๋ ์ด์
ํ๊ธฐ ์ํด ์ฐ๋ฆฌ๋ ๋ ๊ฐ์ ๋ณํ๋ฅผ ์ค์ ์ ์ผ๋ก ๊ณ ๋ คํ๋ค.
- ๋ค์ํ ํ
์คํธ ๋ฐฐ์น์ ํฌ๊ธฐ โฃBTโฃ - ์์ ๋ฐฐ์น์ฌ์ด์ฆ๋ ์งง์ ์ง์ฐ์ผ๋ก ์ด์ด์ง๋ค.
- ์ฌ๋ฌ ๊ฐ, N ๊ฐ์ ๋๋ฉ์ธ DTโ={DT,iโ}i=1Nโ์ผ๋ก์ ์ ์
2.2. Test-Time Normalization Layer
- BN ๋ ์ด์ด์ ์
๋ ฅ์ zโRBCHW๋ผ ํ์. (๋ฐฐ์น ์ฌ์ด์ฆ B, ์ฑ๋ ์ C, ๋์ด์ ๋๋น H,W)
- z์ ํ๊ท ๊ณผ ๋ถ์ฐ์ ฮผ์ ฯ2์ด๊ณ , ๊ฐ๊ฐ์ ๋ค์๊ณผ ๊ฐ์ด ๋ํ๋๋ค.
ฮผcโ=BHW1โbโBโhโHโwโWโzbchwโ,ฯc2โ=BHW1โbโBโhโHโwโWโ(zbchwโโฮผcโ)2(1)
- BN๋ ์ด์ด์์ z๋ ๋จผ์ ฮผ์ ฯ2๋ก ํ์คํ๋ ๋ค, ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ ฮณ,ฮฒโRC๋ก affine ๋ณํ๋๋ค.
- ํ์คํ๋ ํ์ฌ input batch statistics๋ฅผ ํ๋ จ ์ค์ ์ฌ์ฉํ๊ณ , ํ
์คํธ ์ค์๋ ์ถ์ ๋ source statistics ฮผsโ์ ฯs2โ๋ฅผ ์ฌ์ฉํ๋ค.
- ๋ํ domain shift๋ฅผ ๊ณ ๋ คํ๊ธฐ ์ํด, ํ์ต ๊ฐ๋ฅํ ๋ณด๊ฐ ๊ฐ์ค์น ฮฑโR์ ์ด์ฉํ์ฌ source batch statistics์ test batch statistics๋ฅผ ๊ฒฐํฉํ๋ค.
ฮผ~โ=ฮฑฮผ+(1โฮฑ)ฮผsโ,ฯ~2=ฮฑฯ2+(1โฮฑ)ฯs2โ+ฮฑ(1โฮฑ)(ฮผโฮผsโ)2(2)

- ์ ๊ทธ๋ฆผ์์ ๋ณผ ์ ์๋ฏ, CBN๊ณผ TBN์ ฮฑ๋ก ๊ฒฐํฉํ์ฌ ์ฌ์ฉํ๋ค.
- ฮฑ๋ post-train ๋จ๊ณ์์ ํ์ต๋๊ณ , ์ดํ ๊ณ ์ ๋๋ค.
- ์ด ๊ณผ์ ์ ํตํด ๊ฐ ์ฑ๋๊ณผ ๋ ์ด์ด๋ง๋ค ๋ค๋ฅธ ฮฑcโ๋ฅผ ๊ฐ๊ฒ ๋๋ค.
2.3. Post Training
์์์ ๋งํ ฮฑ๋ฅผ ํ์ต์ํค๋ post-training์ ๋ํด ์ค๋ช
ํ๋ค.
- ฮฑ๋ฅผ ์ ์ธํ ๋ชจ๋ ํ๋ผ๋ฏธํฐ๋ ๊ณ ์ ๋๋ฉฐ, post-training ๊ณผ์ ์๋ source data์ ์ ๊ทผํ ์ ์๋ค.
- ๋จผ์ ์ด๋ค ๋ ์ด์ด์ ์ฑ๋์ด domain shift์ ๋ฏผ๊ฐํ์ง๋ฅผ ๋ํ๋ด๋ ฮฑ์ ์ฌ์ ์ง์ A๋ฅผ ๊ตฌํ๋ค.
- ๊ทธ ํ, ์ฌ์ ์ง์๊ณผ ์ถ๊ฐ์ ์ธ Objective term์ ์ด์ฉํ์ฌ ฮฑ๋ฅผ ์ต์ ํํ๋ค.

2.3.1 Obtain Prior A
- Augmentation์ ํตํด source data์ ๋ํ domain shift๋ฅผ ์๋ฎฌ๋ ์ด์
ํ๋ค.(xโฒ)
- ๋จผ์ ์ด๋ค ๋ ์ด์ด์ ์ฑ๋์ ํ์คํ ์์น๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ๋์ด์ผ ํ๋์ง ์๊ธฐ ์ํด z(l,c)์ standardized feature z^(l,c)๋ฅผ ๊ตฌํ๋ค.
- ์ด๋ฅผ domain-shift๋ z^โฒ(l,c)์ ๋น๊ตํ๋ค.
- ์ฌ์ ํ์ต๋ CBN์ด ๋ ์
๋ ฅ์ ๋ํด ๋์ผํ ฮผs(l,c)โ, ฯs(l,c)โ๋ฅผ ์ฌ์ฉํ๋ฏ๋ก, z^โฒ(l,c)์ z^(l,c)์ ์ฐจ์ด๋ x์ xโฒ์ ์ฐจ์ด๋ก ์ธํด ๋ฐ์ํ๋ค.
- ๋ง์ฝ ์ด ์ฐจ์ด๊ฐ ํฌ๋ค๋ฉด (l,c)๋ domain shift์ ๋ฏผ๊ฐํ๋ค๊ณ ํ๋จํ ์ ์๋ค.
- ์ด๋ฅผ ์ดํ์ธ ํ๋ผ๋ฏธํฐ ฮณ,ฮฒ์ gradient โฮณโ,โฮฒโ ๋ฅผ ๋น๊ตํจ์ผ๋ก์จ ์ธก์ ํ ์ ์๋ค.
- ๊ทธ๋ฆฌ๊ณ ์ ๊ทธ๋ฆผ์์ ๋ณผ ์ ์๋ฏ, ์ฐ๋ฆฌ๋ cross-entropy loss LCEโ๋ฅผ ์ด์ฉํด์ โฮณโ,โฮฒโ ๋ฅผ ์ป์ ์ ์๋ค.
- ์ต์ข
์ ์ผ๋ก, gradient distance score d(l,c)โR์ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํ๋ค.
s=N1โi=1โNโโฅgiโโฅโฅgiโฒโโฅgiโโ
giโฒโโ,(3)
d(l,c)=1โ21โ(sฮณ(l,c)โ+sฮฒ(l,c)โ),(4)
- ์ฌ๊ธฐ์ (g,gโฒ)๋ sฮณ(l,c)โ์ sฮฒ(l,c)โ์ ๋ํด (โฮณ(l,c)โ,โฮณโฒ(l,c)โ), (โฮฒ(l,c)โ,โฮฒโฒ(l,c)โ)๋ฅผ ์๋ฏธํ๋ค.
- N์ ํ๋ จ ๋ฐ์ดํฐ ์๋ฅผ ์๋ฏธํ๋ฏ๋ก d(l,c)๋ [0,1]์ ๊ฐ์ ๊ฐ๋๋ค.
- ์๋์ ์ธ ์ฐจ์ด๋ฅผ ๊ฐ์กฐํ๊ธฐ ์ํด, ์ฐ๋ฆฌ๋ ์ต์ข
์ ์ผ๋ก ์ ๊ณฑ์ ์ทจํ์ฌ ์ฌ์ ์ง์ A๋ฅผ ๊ตฌํ๋ค.
A=[d(1,.),d(2,.),โฆ,d(L,.)]2,(5)
- ์ฌ๊ธฐ์ d(l,.)๋ [d(l,c)]c=1Clโโ์ ์๋ฏธํ๋ค.
2.3.2 Optimize ฮฑ
- ์ฌ์ A๊ฐ ์ป์ด์ง ์ดํ์ ์ฐ๋ฆฌ๋ ์ดํ์ธ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ์งํ ์ฑ๋ก CBN์ TTN๋ ์ด์ด๋ก ๋์ฒดํ ์ ์๋ค.
- ๊ทธ๋ค์ ฮฑ๋ฅผ A๋ก ์ด๊ธฐํํ๋ค.
- ๋ถํฌ ๋ณํ๋ฅผ ์๋ฎฌ๋ ์ด์
ํ๊ธฐ ์ํด, ์ฐ๋ฆฌ๋ augmented training data๋ฅผ ์ฌ์ฉํ๋ค.
- ๋ชจ๋ธ์ด ๋ณธ๋ input๊ณผ augmented input์ ๋ํด ๋์ผํ ์ฑ๋ฅ์ ๋ด๋๋ก ฮฑ๋ฅผ ์ต์ ํํ๊ธฐ ์ํด cross-entropy loss LCEโ๋ฅผ ์ฌ์ฉํ๋ค.
- ๋ํ ฮฑ๊ฐ ๋ณธ๋ A์์ ๋๋ฌด ๋ฉ์ด์ง์ง ์๋๋ก, mean-squared error loss LMSEโ=โฅฮฑโAโฅ2๋ฅผ ์ถ๊ฐํ๋ค.
- ์ต์ข
loss L์ L=LCEโ+ฮปLMSEโ(6)๋ก ์ ์๋๋ฉฐ, ฮป๋ weighting hyperparameter์ด๋ค.