Source data ์์ด ์ค์ง target data ๋ง์ผ๋ก ์งํํ ์ ์๋ Fully test-time adaptation์ ๊ฐ์กฐํ๋ค.
๋ฅผ adaptation ๋ชฉํ๋ก ์ฌ์ฉํ๋ ๋ฅผ ์ ์ํ๋ค.
์๊ณก์ ๊ฐ๊ฑดํจ์ ๋ณด์ด๊ธฐ ์ํด ImageNet-C์์์ ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํ๋ค.
Domain adaptation์ด ๊ฐ๋ฅํจ์ ๋ณด์ด๊ธฐ ์ํด Digit classification (SVHN-MNIST)๊ณผ ์๋ฎฌ๋ ์ด์ -์ค์ ํ๊ฒฝ์์์ Semantic Segmentation์ ๋น๊ตํ๋ค.
Source data ๋ก ํ์ต๋ ๋ shifted target data ์ ๋ํด ์ ์ฉ๋๊ธฐ ์ด๋ ต๋ค.
์ฐ์ ๊ธฐ์กด์ TTA๋ฅผ ์ํ ์ฐ๊ตฌ๋ฅผ ๋์ดํ๋ค.
๊ทธ๋ฌ๋ ์์ ๋ชปํ test data๊ฐ ์๊ธฐ ๋๋ฌธ์, TTT์ ๋ณธ ๋ ผ๋ฌธ์ model์ ํ ์คํธ ์ค unsupervised loss ๋ฅผ ์ต์ ํํด์ผ ํ๋ค.
Fully Test-Time Adaptation์ ๋ก ํํ๋๋ training data์ training loss์ ๋ ๋ฆฝ์ ์ด๋ฉฐ, ์ด๋ฅผ ๋ณ์กฐํ์ง ์๊ณ , ์ ์ ๋ฐ์ดํฐ์ ์ฐ์ฐ์ผ๋ก adaptation์ ์ํํ๋ค.
๋ณธ ๋ ผ๋ฌธ์ Feature๋ฅผ ๋ณ์กฐํจ์ผ๋ก์จ ์์ธก ์ํธ๋กํผ๋ฅผ ์ต์ํํ๋ ๋ฐฉํฅ์ผ๋ก ํ ์คํธ ์ค model์ ์ต์ ํํ๋ค.
Model์ ๋ฐ๋์ supervised task๋ก ํ๋ จ๋์ด ์์ด์ผ ํ๋ฉฐ, ํ๋ฅ ์ ์ด๊ณ , ๋ฏธ๋ถ๊ฐ๋ฅํด์ผํ๋ค.
์ ํ์ ์ธ ์ง๋ํ์ต์ ์ํ DNN ๋คํธ์ํฌ๋ ์ด๋ฅผ ์ถฉ์กฑํ ๊ฒ์ด๋ค.
Test-time์ ๋ชฉํ ๋ ๋ชจ๋ธ์ ์์ธก ์ ๋ํด ์ํธ๋กํผ ๋ฅผ ์ต์ํํ๋ ๊ฒ์ด๋ค.
๋ ์์ธํ, ํด๋์ค ์ ํ๋ฅ ์ ๋ํด Shannon entropy ๋ฅผ ๊ณ์ฐํ๋ค.
ํ ๊ฐ์ ์์ธก์ ์ต์ ํํ๊ธฐ ์ํด ๋ชจ๋ ํ๋ฅ ์ ๊ฐ์ฅ ๋์ ํ๋ฅ ์๊ฒ ๋ชฐ์์ฃผ๋ ๋ฐฉ๋ฒ์ด ์๋ค.
- ์ด๋ฅผ ๋ง๊ธฐ ์ํด Tent์์๋ ์ฌ๋ฌ ๋ฐ์ดํฐ ํฌ์ธํธ๋ฅผ ๋ฌถ์ ๋ฐฐ์น๋ฅผ ์ฌ์ฉํ๊ณ , ๋ ๋์ ๋ฒ์์ ์ ๋ ฅ์ ๋ํด ์์ ์ ์ธ ์์ธก์ ํ ์ ์๋๋ก ๋๋๋ค.
- ๋ํ ๋ฐฐ์น ๋ด์ ๋ชจ๋ ์์ธก์ ๊ณต์ ๋ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ์ฌ ์ต์ ํํจ์ผ๋ก์จ ๋ชจ๋ธ์ด ๋ฐฐ์น๋ฅผ ์ผ๊ด์ฑ์๊ฒ ์ฒ๋ฆฌํ๋๋ก ํ๋ค.
Model parameter ๋ test-time์์ ์ฌ์ฉํ ์ ์๋ ์์ฐ์ค๋ฌ์ด ์ ํ์ด๋ฉฐ, ์ ํ ์ฐ๊ตฌ๋ค์์๋ ์ฌ์ฉ๋์๋ค.
๊ทธ๋ฌ๋ ๋ source data๋ฅผ ํํํ ์์๋ ์ ์ผํ ์๋จ์ด๋ฉฐ, ์ด๋ฅผ ๋ณํํ๋ ๊ฒ์ ๋ชจ๋ธ์ด ํ๋ จ ๋ฐ์ดํฐ๋ก๋ถํฐ ๋ฒ์ด๋๊ฒ ๋ง๋ค ์ ์๋ค.
๋ ๋์๊ฐ ๋ ๋น ์ ํ์ ์ด๊ณ , ๋ ๋์ ์ฐจ์์ ๊ฐ๊ธฐ ๋๋ฌธ์, test-time์ ์ฌ์ฉํ๊ธฐ์๋ ๋๋ฌด ๋ฏผ๊ฐํ๊ณ ๋นํจ์จ์ ์ด๋ค.
๋ฐ๋ผ์ ์์ ์ฑ๊ณผ ํจ์จ์ฑ์ ์ํด, ์ ํ์ ์ด๊ณ ์ ์ฐจ์์ feature ๋ง ์ ๋ฐ์ดํธํ๋ค.
๊ตฌํ์ ์ํด tent๋ model์ normalization layer๋ฅผ ๋จ์ํ ์ฌํ์ฉํ์ฌ ํ ์คํธ ๋์ ๋ชจ๋ ๋ ์ด์ด์ ๋ํด ๋งค๊ฐ๋ณ์๋ค์ ์ ๋ฐ์ดํธํ๋ค.
def setup_tent(model):
model = tent.configure_model(model)
params, param_names = tent.collect_params(model) // parameter ์์ง
optimizer = setup_optimizer(params) // optimizer์๊ฒ ์ ๋ฌ
tent_model = tent.Tent(model, optimizer, steps=cfg.OPTIM.STEPS, episodic=cfg.MODEL.EPISODIC)
return tent_model
def setup_optimizer(params):
if cfg.OPTIM.METHOD == 'Adam':
return optim.Adam(params, lr=cfg.OPTIM.LR, ... ) // optimize ๋์์ params๋ก ์ค์
def collect_params(model):
params = []
names = []
for nm, m in model.named_modules():
if isinstance(m, nn.BatchNorm2d): // ๋ฐฐ์น ์ ๊ทํ ๋ ์ด์ด
for np, p in m.named_parameters():
if np in ['weight', 'bias']: # weight is scale, bias is shift
params.append(p)
names.append(f"{nm}.{np}")
return params, names
๋งค ์คํ ๋ง๋ค normalization statistics์ transformation parameter๋ฅผ ๊ฐ ๋ฐฐ์น๋ง๋ค ์ ๋ฐ์ดํธํ๋ค.
class Tent(nn.Module):
...
def forward(self, x):
...
for _ in range(self.steps):
outputs = forward_and_adapt(x, self.model, self.optimizer)
return outputs
@torch.enable_grad() # ensure grads in possible no grad context for testing
def forward_and_adapt(x, model, optimizer):
outputs = model(x) // output ๊ณ์ฐ
loss = softmax_entropy(outputs).mean(0) // -(x.softmax(1) * x.log_softmax(1)).sum(1)
loss.backward() // gradient ๊ณ์ฐ
optimizer.step() // parameter ์
๋ฐ์ดํธ
optimizer.zero_grad() // gradient ์ด๊ธฐํ
return outputs
Image classification for corruption
Image classification with domain adaptation
๋์กฐ๊ตฐ์ ๋ค์๊ณผ ๊ฐ์ด ์ค์ ํ๋ค.
BN, PL, TENT ๋ง์ด fully test-time adaptaion์ด๋ฉฐ, ๋๋จธ์ง๋ source ๋ฐ์ดํฐ๊ฐ ํ์ํ๋ค.
15๊ฐ์ง ์ ํ์ corruption์ ๋ค์ฏ๊ฐ์ง ์ฌ๊ฐ๋๋ก ์ ์ฉํ์ฌ CIFAR-10/100, ImageNet-C์์ ํ ์คํธํ๋ค.
SVHN(๊ฑฐ๋ฆฌ์ ์ง ๋ฒํธ, ์ ์ฑ์) -> MNIST(์๊ธ์จ ์ซ์, ๋ฌด์ฑ์)
GTA(๊ฒ์) -> CityScape(์์จ์ฃผํ ๋ฐ์ดํฐ์ ) ์์์ Semantic Segmentation ๊ฒฐ๊ณผ