๐Ÿ“ฆ๏ธ TENT: Fully Test-Time Adaptation by Entropy Minimization


Contributions

  • Emphasizes fully test-time adaptation, which proceeds with only target data and no source data.

  • Proposes TENT, which uses entropy as the adaptation objective.

  • Compares results on ImageNet-C to demonstrate robustness to corruptions.

  • Compares digit classification (SVHN → MNIST) and simulation-to-real semantic segmentation to show that domain adaptation is possible.


Setting: Fully Test-Time Adaptation

A model $f_\theta(x)$ trained on source data $x^s, y^s$ is difficult to apply to shifted target data $x^t$.

First, the existing lines of work related to test-time adaptation:

  • Transfer learning by fine-tuning: requires target labels to compute the supervised loss $L(x^t, y^t)$.
  • Domain adaptation: trains with a cross-domain loss $L(x^s, x^t)$, so both source and target data are needed.
  • Test-time training (TTT): before adaptation, the model is first trained to jointly optimize the supervised loss $L(x^s, y^s)$ and a self-supervised loss $L(x^s)$.

However, because test data can shift in unexpected ways, both TTT and this work must optimize an unsupervised loss $L(x^t)$ on the model during testing.

Fully test-time adaptation depends on the training data and training loss only through the trained parameters $\theta$: it does not modify them, and it performs adaptation with little data and computation.

  • In summary, the settings compare as follows (a reconstruction of the paper's comparison table from the losses listed above):
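
| Setting | Source data | Target data | Training loss | Test loss |
|---|---|---|---|---|
| Fine-tuning | – | $x^t, y^t$ | $L(x^t, y^t)$ | – |
| Domain adaptation | $x^s, y^s$ | $x^t$ | $L(x^s, y^s) + L(x^s, x^t)$ | – |
| Test-time training (TTT) | $x^s, y^s$ | $x^t$ | $L(x^s, y^s) + L(x^s)$ | $L(x^t)$ |
| Fully test-time adaptation | – | $x^t$ | – | $L(x^t)$ |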

Method: Test Entropy Minimization via Feature Modulation

๋ณธ ๋…ผ๋ฌธ์€ Feature๋ฅผ ๋ณ€์กฐํ•จ์œผ๋กœ์จ ์˜ˆ์ธก ์—”ํŠธ๋กœํ”ผ๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š” ๋ฐฉํ–ฅ์œผ๋กœ ํ…Œ์ŠคํŠธ ์ค‘ model์„ ์ตœ์ ํ™”ํ•œ๋‹ค.

Model์€ ๋ฐ˜๋“œ์‹œ supervised task๋กœ ํ›ˆ๋ จ๋˜์–ด ์žˆ์–ด์•ผ ํ•˜๋ฉฐ, ํ™•๋ฅ ์ ์ด๊ณ , ๋ฏธ๋ถ„๊ฐ€๋Šฅํ•ด์•ผํ•œ๋‹ค.

  • No supervision of any kind is given during testing, so the model must already be trained.
  • Entropy is computed from the distribution of predictions, so the model must output a probability distribution.
  • Fast, iterative optimization requires gradients, so the model must be differentiable.

Typical deep networks for supervised learning satisfy all of these requirements.

Entropy Objective

Test-time์˜ ๋ชฉํ‘œ L(xt)L(x_t)๋Š” ๋ชจ๋ธ์˜ ์˜ˆ์ธก y^=fฮธ(xt)\hat{y} = f_\theta(x^t)์— ๋Œ€ํ•ด ์—”ํŠธ๋กœํ”ผ H(y^)H(\hat{y})๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š” ๊ฒƒ์ด๋‹ค.

๋” ์ž์„ธํžˆ, ํด๋ž˜์Šค cc์˜ ํ™•๋ฅ  y^c\hat{y}_c์— ๋Œ€ํ•ด Shannon entropy H(y^)=โˆ’โˆ‘cp(y^c)logโกp(y^c)H(\hat{y}) = -\sum_c p(\hat{y}_c)\log p(\hat{y}_c)๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค.

ํ•œ ๊ฐœ์˜ ์˜ˆ์ธก์„ ์ตœ์ ํ™”ํ•˜๊ธฐ ์œ„ํ•ด ๋ชจ๋“  ํ™•๋ฅ ์„ ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์—๊ฒŒ ๋ชฐ์•„์ฃผ๋Š” ๋ฐฉ๋ฒ•์ด ์žˆ๋‹ค.

  • To prevent this, Tent uses batches of multiple data points, which helps the model make stable predictions over a wider range of inputs.
  • It also optimizes all predictions in a batch through shared parameters, so the model must process the batch consistently.
  • Proxy tasks could be considered instead, but they require restricting or mixing updates, and choosing a suitable proxy takes substantial effort. The entropy objective needs none of this.
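
As a minimal sketch, the batched entropy objective can be computed directly from the model's logits; this matches the softmax_entropy helper referenced in the code further below:

import torch

def softmax_entropy(logits: torch.Tensor) -> torch.Tensor:
    # Shannon entropy of the softmax distribution, per sample; logits: [N, C]
    return -(logits.softmax(dim=1) * logits.log_softmax(dim=1)).sum(dim=1)

# The test-time loss is the mean entropy over the batch:
# loss = softmax_entropy(model(x)).mean(0)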

Modulation Parameters

The model parameters $\theta$ are a natural choice of what to adapt at test time, and prior work has adapted them.

However, $\theta$ is the only available representation of the source data, and altering it can make the model drift away from its training.

Moreover, $f$ is nonlinear and $\theta$ is high-dimensional, so adapting all of $\theta$ at test time is too sensitive and inefficient.

๋”ฐ๋ผ์„œ ์•ˆ์ •์„ฑ๊ณผ ํšจ์œจ์„ฑ์„ ์œ„ํ•ด, ์„ ํ˜•์ ์ด๊ณ  ์ €์ฐจ์›์˜ feature ๋งŒ ์—…๋ฐ์ดํŠธํ•œ๋‹ค.

  • First, the input $x$ is normalized: $\bar{x} = (x - \mu)/\sigma$.
  • Then $\bar{x}$ is transformed with affine parameters (scale $\gamma$ and shift $\beta$): $x' = \gamma \bar{x} + \beta$.

๊ตฌํ˜„์„ ์œ„ํ•ด tent๋Š” model์˜ normalization layer๋ฅผ ๋‹จ์ˆœํžˆ ์žฌํ™œ์šฉํ•˜์—ฌ ํ…Œ์ŠคํŠธ ๋™์•ˆ ๋ชจ๋“  ๋ ˆ์ด์–ด์— ๋Œ€ํ•ด ๋งค๊ฐœ๋ณ€์ˆ˜๋“ค์„ ์—…๋ฐ์ดํŠธํ•œ๋‹ค.


Algorithm

Initialization

  • The optimizer collects the affine transformation parameters $\{\gamma_{l,k}, \beta_{l,k}\}$ for every normalization layer $l$ and channel $k$ of the source model.
  • The remaining parameters $\theta \setminus \{\gamma_{l,k}, \beta_{l,k}\}$ stay frozen.
def setup_tent(model):
    model = tent.configure_model(model)
    params, param_names = tent.collect_params(model)  # collect the affine parameters
    optimizer = setup_optimizer(params)  # hand them to the optimizer
    tent_model = tent.Tent(model, optimizer, steps=cfg.OPTIM.STEPS, episodic=cfg.MODEL.EPISODIC)
    return tent_model

def setup_optimizer(params):
    if cfg.OPTIM.METHOD == 'Adam':
        return optim.Adam(params, lr=cfg.OPTIM.LR, ...)  # optimize only the collected params

def collect_params(model):
    params = []
    names = []
    for nm, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):  # batch normalization layers only
            for np, p in m.named_parameters():
                if np in ['weight', 'bias']:  # weight is scale (gamma), bias is shift (beta)
                    params.append(p)
                    names.append(f"{nm}.{np}")
    return params, names
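
configure_model, called in setup_tent above, is not shown in this post. Roughly, following the official repository, it puts the model in train mode, freezes everything except the batch-norm affine parameters, and discards the source running statistics so each test batch is normalized by its own statistics:

def configure_model(model):
    model.train()  # BN layers then normalize with batch statistics
    model.requires_grad_(False)  # freeze everything ...
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)  # ... except the affine parameters gamma, beta
            m.track_running_stats = False  # drop the source statistics
            m.running_mean = None
            m.running_var = None
    return model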

Iteration

At every step, the normalization statistics and the transformation parameters are updated for each batch.

  • The normalization statistics are estimated during the forward pass.
  • The transformation parameters $\gamma, \beta$ are updated during the backward pass with the gradient of the prediction entropy, $\nabla H(\hat{y})$.
class Tent(nn.Module):
    ...
    def forward(self, x):
        ...

        for _ in range(self.steps):
            outputs = forward_and_adapt(x, self.model, self.optimizer)

        return outputs

@torch.enable_grad()  # ensure gradients in a possible no-grad context during testing
def forward_and_adapt(x, model, optimizer):
    outputs = model(x)  # forward pass

    loss = softmax_entropy(outputs).mean(0)  # softmax_entropy(x) = -(x.softmax(1) * x.log_softmax(1)).sum(1)
    loss.backward()  # compute the entropy gradient
    optimizer.step()  # update gamma and beta
    optimizer.zero_grad()  # reset gradients

    return outputs

Termination

  • In online adaptation no termination is needed: adaptation keeps running as long as test data arrives (see the sketch after this list).
  • In offline adaptation, the model is first updated on the test data, and inference can then be repeated over multiple epochs.
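
A hypothetical usage sketch of the online loop (base_model and test_loader are placeholders):

tent_model = setup_tent(base_model)  # base_model: a source-trained network
for x, _ in test_loader:  # target labels are never used
    preds = tent_model(x)  # each forward pass also performs an adaptation step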

Experiments

Datasets

  • Image classification under corruption

    • ImageNet (1,000 classes; 1.2M training images + 50,000 validation images)
    • CIFAR-10/CIFAR-100 (10/100 classes; 50,000 training images + 10,000 test images)
  • Image classification with domain adaptation

    • SVHN as source (73,257 training images + 26,032 test images)
    • MNIST/MNIST-M/USPS as targets (60,000/60,000/7,291 training images + 10,000/10,000/2,007 test images)

Models

  • ResNet models are used.
  • A 26-layer ResNet (R-26) for CIFAR-10/100 and a 50-layer ResNet (R-50) for ImageNet.

Optimization

  • SGD with momentum for ImageNet; Adam for the other datasets (see the sketch after this list).
  • ImageNet: learning rate 0.00025, batch size 64.
  • Other datasets: learning rate 0.001, batch size 128.
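
A sketch of these optimizer choices (the momentum value is an assumption, not stated in the post; make_optimizer is a hypothetical helper):

import torch.optim as optim

def make_optimizer(params, dataset):
    if dataset == 'imagenet':
        return optim.SGD(params, lr=0.00025, momentum=0.9)  # momentum assumed; batch size 64
    return optim.Adam(params, lr=0.001)  # batch size 128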

Baselines

๋Œ€์กฐ๊ตฐ์„ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์„ค์ •ํ•œ๋‹ค.

  • Source: the classifier without any adaptation.
  • RG: adversarial domain adaptation.
  • UDA-SS: self-supervised domain adaptation.
  • TTT: test-time training.
  • BN: batch normalization statistics re-estimated on target data.
  • PL: pseudo-labeling, which trains the model using predictions above a confidence threshold as labels.

Only BN, PL, and TENT are fully test-time adaptation; the rest require source data.

Robustness to Corruption

Fifteen corruption types are applied at five severity levels, and the models are tested on corrupted CIFAR-10/100 and ImageNet-C.

Source-Free Domain Adaptation

SVHN (street-view house numbers, color) → MNIST (handwritten digits, grayscale)

  • Source ๋ฐ์ดํ„ฐ ์—†์ด๋„ adaptation์ด ๊ฐ€๋Šฅํ•œ ๊ฒƒ์„ ๋ณด์ž„.
  • ๋” ์ ๊ฒŒ ์—ฐ์‚ฐํ•˜์ง€๋งŒ ๋” ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์ž„.

Semantic segmentation results are also reported for GTA (game imagery) → Cityscapes (real driving scenes).


Analysis

  • Most points lie in the lower left → TENT reduces both the loss and the entropy.
  • The dark diagonal trend ($\rho = 0.22$) → the loss and the entropy are correlated.

Ablation study

  • Without normalization: error becomes higher than BN or PL.
  • Without transformation: reduces to the BN baseline, since only the batch statistics are updated.

Alternative architecture

  • Tent is model-agnostic.
  • The paper also evaluates corruption robustness for self-attention (SAN) and equilibrium-solving (MDEQ) architectures.
