๐Ÿ“ฆ๏ธ TTN: A Domain-Shift Aware Batch Normalization in Test-Time Adaptation

Bard · April 7, 2025


Contributions

  • ์ฑ„๋„๋ณ„ ๋ณด๊ฐ„ ๊ฐ€์ค‘์น˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ Source batch statistics(CBN)์™€ Test batch statistics(TBN)๋ฅผ ๊ฒฐํ•ฉํ•˜๋ฉฐ, ์ƒˆ๋กœ์šด ๋„๋ฉ”์ธ์— ๋Œ€ํ•ด ์ ์‘ํ•˜๋ฉด์„œ๋„ ์›์ฒœ์ง€์‹์„ ๋ณด์กดํ•˜๋Š” TTN ๋ ˆ์ด์–ด๋ฅผ ์ œ์•ˆํ•œ๋‹ค.
  • ๊ธฐ์กด์˜ TTA ๋ฐฉ๋ฒ•์— TTN์„ ์ถ”๊ฐ€ํ•˜๋ฉด ํ…Œ์ŠคํŠธ ๋ฐฐ์น˜ ํฌ๊ธฐ(1์—์„œ 200๊นŒ์ง€)์˜ ํญ๋„“์€ ๋ฒ”์œ„์—์„œ ์„ฑ๋Šฅ์ด ํฌ๊ฒŒ ํ–ฅ์ƒ๋จ์„ ๋ณด์ธ๋‹ค.

2. Methodology

2.1. Problem Setup

  • ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์™€ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ DS,โ€…โ€ŠDT\mathcal{D}_S,\;\mathcal{D}_T, ๊ฐ๊ฐ์— ๋Œ€ํ•œ ํ™•๋ฅ ๋ถ„ํฌ๋ฅผ PS,โ€…โ€ŠPTP_S,\;P_T๋ผ ํ•˜์ž.
  • TTA์˜ Covariate Shift๋Š” PSโ‰ PTP_S \neq P_T where PS(yโˆฃx)=PT(yโˆฃx)P_S(y|x) = P_T(y|x)์ด๋‹ค.
  • ๋ชจ๋ธ fฮธf_\theta๋Š” DS\mathcal{D}_S์˜ mini-batch BS={(xi,yi)}i=1โˆฃBSโˆฃ\mathcal{B}^S = \left\{(x_i, y_i)\right\}^{|\mathcal{B}^S|}_{i=1}๋กœ ํ•™์Šต๋˜์—ˆ์œผ๋ฉฐ,
  • ํ…Œ์ŠคํŠธ ์ค‘์—๋Š” fฮธf_\theta๋Š” ํ…Œ์ŠคํŠธ ๋ฐฐ์น˜ BTโˆผDT\mathcal{B}_T\sim\mathcal{D}_T๋ฅผ ๋งŒ๋‚˜๊ฒŒ ๋œ๋‹ค.
  • ๊ทธ๋ฆฌ๊ณ  TTA์˜ ๋ชฉํ‘œ๋Š” ๋‹ค๋ฅธ ๋ถ„ํฌ๋กœ๋ถ€ํ„ฐ ํ…Œ์ŠคํŠธ ๋ฐฐ์น˜๋ฅผ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ฒƒ์ด๋‹ค.

์ข€ ๋” ์‹ค์šฉ์ ์ธ TTA๋ฅผ ์‹œ๋ฎฌ๋ ˆ์ด์…˜ํ•˜๊ธฐ ์œ„ํ•ด ์šฐ๋ฆฌ๋Š” ๋‘ ๊ฐœ์˜ ๋ณ€ํ™”๋ฅผ ์ค‘์ ์ ์œผ๋กœ ๊ณ ๋ คํ•œ๋‹ค.

  1. Varying test batch sizes $|\mathcal{B}^T|$: a small batch size leads to low latency.
  2. Adaptation to multiple ($N$) target domains $\mathcal{D}_T = \{\mathcal{D}_{T,i}\}_{i=1}^N$.

2.2. Test-Time Normalization Layer

  • BN ๋ ˆ์ด์–ด์˜ ์ž…๋ ฅ์„ zโˆˆRBCHW\mathrm{z}\in\mathbb{R}^{BCHW}๋ผ ํ•˜์ž. (๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ BB, ์ฑ„๋„ ์ˆ˜ CC, ๋†’์ด์™€ ๋„ˆ๋น„ H,WH,W)
  • z\mathrm{z}์˜ ํ‰๊ท ๊ณผ ๋ถ„์‚ฐ์€ ฮผ\mu์™€ ฯƒ2\sigma^2์ด๊ณ , ๊ฐ๊ฐ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋‚˜ํƒ€๋‚œ๋‹ค.
    ฮผc=1BHWโˆ‘bBโˆ‘hHโˆ‘wWzbchw,ฯƒc2=1BHWโˆ‘bBโˆ‘hHโˆ‘wW(zbchwโˆ’ฮผc)2(1)\mu_c = \frac 1 {BHW} \sum^B_b\sum^H_h\sum^W_w\mathrm{z}_{bchw},\qquad \sigma_c^2 = \frac 1 {BHW} \sum^B_b\sum^H_h\sum^W_w(\mathrm{z}_{bchw}-\mu_c)^2 \tag{1}
  • BN๋ ˆ์ด์–ด์—์„œ z\mathrm{z}๋Š” ๋จผ์ € ฮผ\mu์™€ ฯƒ2\sigma^2๋กœ ํ‘œ์ค€ํ™”๋œ ๋’ค, ํ•™์Šต ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ ฮณ,ฮฒโˆˆRC\gamma, \beta \in \mathbb{R}^C๋กœ affine ๋ณ€ํ™˜๋œ๋‹ค.
  • ํ‘œ์ค€ํ™”๋Š” ํ˜„์žฌ input batch statistics๋ฅผ ํ›ˆ๋ จ ์ค‘์— ์‚ฌ์šฉํ•˜๊ณ , ํ…Œ์ŠคํŠธ ์ค‘์—๋Š” ์ถ”์ •๋œ source statistics ฮผs\mu_s์™€ ฯƒs2\sigma_s^2๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค.
  • ๋˜ํ•œ domain shift๋ฅผ ๊ณ ๋ คํ•˜๊ธฐ ์œ„ํ•ด, ํ•™์Šต ๊ฐ€๋Šฅํ•œ ๋ณด๊ฐ„ ๊ฐ€์ค‘์น˜ ฮฑโˆˆR\alpha \in \mathbb{R}์„ ์ด์šฉํ•˜์—ฌ source batch statistics์™€ test batch statistics๋ฅผ ๊ฒฐํ•ฉํ•œ๋‹ค.
    ฮผ~=ฮฑฮผ+(1โˆ’ฮฑ)ฮผs,ฯƒ~2=ฮฑฯƒ2+(1โˆ’ฮฑ)ฯƒs2+ฮฑ(1โˆ’ฮฑ)(ฮผโˆ’ฮผs)2(2)\tilde{\mu} = \alpha \mu + (1 - \alpha) \mu_s, \quad \tilde{\sigma}^2 = \alpha \sigma^2 + (1 - \alpha) \sigma_s^2 + \alpha (1 - \alpha)(\mu - \mu_s)^2 \tag{2}

  • ์œ„ ๊ทธ๋ฆผ์—์„œ ๋ณผ ์ˆ˜ ์žˆ๋“ฏ, CBN๊ณผ TBN์„ ฮฑ\alpha๋กœ ๊ฒฐํ•ฉํ•˜์—ฌ ์‚ฌ์šฉํ•œ๋‹ค.
  • ฮฑ\alpha๋Š” post-train ๋‹จ๊ณ„์—์„œ ํ•™์Šต๋˜๊ณ , ์ดํ›„ ๊ณ ์ •๋œ๋‹ค.
  • ์ด ๊ณผ์ •์„ ํ†ตํ•ด ๊ฐ ์ฑ„๋„๊ณผ ๋ ˆ์ด์–ด๋งˆ๋‹ค ๋‹ค๋ฅธ ฮฑc\alpha_c๋ฅผ ๊ฐ–๊ฒŒ ๋œ๋‹ค.

2.3. Post Training

์•ž์—์„œ ๋งํ•œ ฮฑ\alpha๋ฅผ ํ•™์Šต์‹œํ‚ค๋Š” post-training์— ๋Œ€ํ•ด ์„ค๋ช…ํ•œ๋‹ค.

  • All parameters except $\alpha$ are frozen, and source data is accessible during post-training.
  • First, a prior $\mathcal{A}$ for $\alpha$ is obtained, indicating which layers and channels are sensitive to domain shift.
  • Then $\alpha$ is optimized using this prior together with an additional objective term.

2.3.1 Obtain Prior $\mathcal{A}$

  • Augmentation์„ ํ†ตํ•ด source data์— ๋Œ€ํ•œ domain shift๋ฅผ ์‹œ๋ฎฌ๋ ˆ์ด์…˜ํ•œ๋‹ค.(xโ€ฒx')
  • ๋จผ์ € ์–ด๋–ค ๋ ˆ์ด์–ด์™€ ์ฑ„๋„์˜ ํ‘œ์ค€ํ™” ์ˆ˜์น˜๊ฐ€ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ๋˜์–ด์•ผ ํ•˜๋Š”์ง€ ์•Œ๊ธฐ ์œ„ํ•ด z(l,c)z^{(l, c)}์˜ standardized feature z^(l,c)\hat{z}^{(l,c)}๋ฅผ ๊ตฌํ•œ๋‹ค.
  • ์ด๋ฅผ domain-shift๋œ z^โ€ฒ(l,c)\hat{z}'^{(l,c)}์™€ ๋น„๊ตํ•œ๋‹ค.
  • ์‚ฌ์ „ ํ•™์Šต๋œ CBN์ด ๋‘ ์ž…๋ ฅ์— ๋Œ€ํ•ด ๋™์ผํ•œ ฮผs(l,c)\mu_s^{(l, c)}, ฯƒs(l,c)\sigma_s^{(l, c)}๋ฅผ ์‚ฌ์šฉํ•˜๋ฏ€๋กœ, z^โ€ฒ(l,c)\hat{z}'^{(l,c)}์™€ z^(l,c)\hat{z}^{(l,c)}์˜ ์ฐจ์ด๋Š” xx์™€ xโ€ฒx'์˜ ์ฐจ์ด๋กœ ์ธํ•ด ๋ฐœ์ƒํ•œ๋‹ค.
  • ๋งŒ์•ฝ ์ด ์ฐจ์ด๊ฐ€ ํฌ๋‹ค๋ฉด (l,c)(l, c)๋Š” domain shift์— ๋ฏผ๊ฐํ•˜๋‹ค๊ณ  ํŒ๋‹จํ•  ์ˆ˜ ์žˆ๋‹ค.
  • ์ด๋ฅผ ์–ดํŒŒ์ธ ํŒŒ๋ผ๋ฏธํ„ฐ ฮณ,ฮฒ\gamma, \beta์˜ gradient โˆ‡ฮณ,โˆ‡ฮฒ\nabla_\gamma, \nabla_\beta ๋ฅผ ๋น„๊ตํ•จ์œผ๋กœ์จ ์ธก์ •ํ•  ์ˆ˜ ์žˆ๋‹ค.
  • ๊ทธ๋ฆฌ๊ณ  ์œ„ ๊ทธ๋ฆผ์—์„œ ๋ณผ ์ˆ˜ ์žˆ๋“ฏ, ์šฐ๋ฆฌ๋Š” cross-entropy loss LCE\mathcal{L}_{CE}๋ฅผ ์ด์šฉํ•ด์„œ โˆ‡ฮณ,โˆ‡ฮฒ\nabla_\gamma, \nabla_\beta ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ๋‹ค.
  • ์ตœ์ข…์ ์œผ๋กœ, gradient distance score d(l,c)โˆˆRd^{(l,c)} \in \mathbb{R}์„ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ •์˜ํ•œ๋‹ค.
$$s = \frac{1}{N} \sum_{i=1}^{N} \frac{g_i \cdot g'_i}{\lVert g_i \rVert \, \lVert g'_i \rVert}, \tag{3}$$
$$d^{(l,c)} = 1 - \frac{1}{2} \left( s^{(l,c)}_{\gamma} + s^{(l,c)}_{\beta} \right), \tag{4}$$
  • Here, $(g, g')$ denotes $(\nabla_\gamma^{(l,c)}, \nabla_{\gamma'}^{(l,c)})$ for $s_\gamma^{(l,c)}$ and $(\nabla_\beta^{(l,c)}, \nabla_{\beta'}^{(l,c)})$ for $s_\beta^{(l,c)}$.
  • $N$ is the number of training samples, and $d^{(l,c)}$ takes values in $[0, 1]$.
  • To emphasize relative differences, the scores are finally squared to obtain the prior $\mathcal{A}$:
    $$\mathcal{A} = [d^{(1,\cdot)}, d^{(2,\cdot)}, \ldots, d^{(L,\cdot)}]^2, \tag{5}$$
  • where $d^{(l,\cdot)}$ denotes $[d^{(l,c)}]_{c=1}^{C_l}$.
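As a rough sketch of how the prior $\mathcal{A}$ of Eqs. (3)–(5) could be computed: the helper below assumes a data loader yielding single source samples and a hypothetical `augment` function that simulates domain shift; the function and variable names are assumptions, not the paper's code.

```python
import torch
import torch.nn.functional as F

def gradient_distance_prior(model, loader, augment, device="cuda"):
    """Sketch of the prior A from Eqs. (3)-(5); names and loop structure are assumptions.

    `loader` is assumed to yield source samples with batch size 1, so the BN affine
    gradients are per-sample gradients as in Eq. (3).
    """
    model.eval()  # CBN keeps the stored source statistics for both the clean and shifted inputs
    bn_layers = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # running sums of per-channel cosine similarities: row 0 for gamma, row 1 for beta
    sims = [torch.zeros(2, m.num_features, device=device) for m in bn_layers]
    n = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        grads = []
        for inp in (x, augment(x)):  # the clean sample and its domain-shifted version
            model.zero_grad()
            F.cross_entropy(model(inp), y).backward()
            grads.append([(m.weight.grad.clone(), m.bias.grad.clone()) for m in bn_layers])
        for i, ((g_w, g_b), (gp_w, gp_b)) in enumerate(zip(*grads)):
            # per-channel affine gradients are scalars, so their cosine similarity
            # reduces to the sign of their product
            sims[i][0] += torch.sign(g_w * gp_w)
            sims[i][1] += torch.sign(g_b * gp_b)
        n += 1

    prior = []
    for s in sims:
        s = s / n                          # Eq. (3): average similarity over the N samples
        d = 1.0 - 0.5 * (s[0] + s[1])      # Eq. (4): gradient distance score per channel
        prior.append(d.pow(2))             # Eq. (5): square to emphasize relative differences
    return prior                           # one per-channel tensor of alpha priors per BN layer
```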

2.3.2 Optimize $\alpha$

  • Once the prior $\mathcal{A}$ is obtained, the CBN layers can be replaced with TTN layers while keeping the affine parameters.
  • Then $\alpha$ is initialized with $\mathcal{A}$.
  • To simulate distribution shift, augmented training data is used.
  • The cross-entropy loss $\mathcal{L}_{CE}$ is used to optimize $\alpha$ so that the model performs equally well on the original and augmented inputs.
  • In addition, a mean-squared error loss $\mathcal{L}_{MSE} = \lVert \alpha - \mathcal{A} \rVert^2$ keeps $\alpha$ from drifting too far from its prior $\mathcal{A}$.
  • The final loss is defined as $\mathcal{L} = \mathcal{L}_{CE} + \lambda \mathcal{L}_{MSE}$ (6), where $\lambda$ is a weighting hyperparameter.
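A sketch of the post-training loop for Eq. (6), reusing the hypothetical `TTNBatchNorm2d` module from Sec. 2.2 above: only the $\alpha$ parameters are updated while everything else stays frozen, and the cross-entropy is applied to both the clean and the augmented batch (the function and argument names are assumptions, not the authors' code).

```python
import torch
import torch.nn.functional as F

def post_train_alpha(model, loader, augment, lam=1.0, lr=1e-3, epochs=1, device="cuda"):
    """Sketch of the post-training objective in Eq. (6); names are assumptions.

    The CBN layers are assumed to have been replaced by TTNBatchNorm2d modules
    (Sec. 2.2 sketch) whose alpha was initialized with the prior A.
    """
    ttn_layers = [m for m in model.modules() if isinstance(m, TTNBatchNorm2d)]
    prior = [m.alpha.detach().clone() for m in ttn_layers]  # A, used by the MSE anchor term
    params = [m.alpha for m in ttn_layers]                   # only alpha is optimized
    opt = torch.optim.SGD(params, lr=lr)

    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            # cross-entropy on both the clean and the augmented (domain-shifted) batch
            loss_ce = F.cross_entropy(model(x), y) + F.cross_entropy(model(augment(x)), y)
            # L_MSE = ||alpha - A||^2, summed over all layers and channels
            loss_mse = sum(((a - p) ** 2).sum() for a, p in zip(params, prior))
            loss = loss_ce + lam * loss_mse  # Eq. (6)
            opt.zero_grad()
            loss.backward()
            opt.step()
```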