Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results[.,2017]

Sungchul Kim·2021년 12월 29일
0

Knowledge-distillation

목록 보기
2/4
post-thumbnail

Summary

본 연구에서는 기존에 알려져 있는 Temporal ensembling 방법에서 한 단계 나아가 Mean teacher방법을 제안합니다. Mean teacher는 model의 weight를 평균 내는 방법인데, 적은 수의 label만 가지고 이전 방법에 비해 좋은 성능을 보였습니다. SVHN에 대하여 약 250개의 label만 가지고 4.35%의 error rate를 보였습니다. 또한 Mean teacher와 Residual Network(Resnet)과 결합하였을 때, 좋은 성능을 보였다고 합니다. (CIFAR-10 with 4000 labels : 6.28%)

Related work

neural network를 regularization하기 위해 여러 노이즈 방법들이 제안되었습니다. Semi-supervised-learning에서도 노이즈 방법들을 적용하게 되는데 이를 Consistency training이라고 합니다.
Consistency training이란, 입력 데이터나 은닉층(hidden)에 약간의 노이즈를 가했을 때, 모델의 예측이 바뀌지 않도록 모델을 regularization하는데 초점을 둔 방법입니다.

⨿ model\amalg \ \mathbf{model}

첫번째, ⨿ model\amalg \ \mathbf{model}입니다. 동작하는 과정은 아래와 같습니다.

  1. Input image xi\mathbb{x_{i}}에 stochastic augmentation을 적용함.
  2. Augmented image xi\mathbb{x_{i}^{'}}에 dropout을 적용함.
  3. Model을 통과하면 2개의 output이 나오게 됨. (zi\mathbb{z}_{i}, z~i\tilde{\mathbb{z}}_{i})
    즉 동일한 이미지 xi\mathbb{x_{i}}를 통해 2개의 output을 출력하게 됩니다.
  4. zi\mathbb{z}_{i}은 hard label(yi\mathbb{y}_{i})과 cross-entropy계산, 2개의 output(zi\mathbb{z}_{i}, z~i\tilde{\mathbb{z}}_{i}) squared difference계산.
    논문에서는 squared difference를 mean squared error로 명시함.
  5. Supervised loss term과 unsupervised loss term을 더함 → loss 최소화.
    w(t)w(t)는 weighted sum할때, Unsupervised Loss의 가중치를 결정하는 함수입니다. training step t가 증가함에 따라 w(t)w(t)는 증가하게 됩니다. unsupervised loss term은 아래와 같습니다.
LSSL=w(t)1BiBziz~i2L_{SSL} = w(t) \frac{1}{|B|} \sum_{i \in B} ||\mathbb{z}_{i} - \tilde{\mathbb{z}}_{i}||^{2}

위 방법은 image에 stochastic augmentation, dropout을 적용하여 Consistency training 하는 데 초점을 두었으나, 하나의 network를 기반해 동작하기 때문에 노이즈가 심하다는 단점이 존재합니다.

Temporal ensembling\mathbf{Temporal \ ensembling}

두번째, Temporal ensembling\mathbf{Temporal \ ensembling} 입니다. 동작하는 과정은 아래와 같습니다.

  1. Input image xi\mathbb{x_{i}}에 stochastic augmentation과 dropout을 적용
  2. Output zi\mathbb{z}_{i}와 hard label(yi\mathbb{y}_{i})간의 cross-entropy계산
  3. zi\mathbb{z}_{i}, z~i\tilde{\mathbb{z}}_{i}간의 squared difference계산
    논문에서는 squared difference를 mean squared error로 명시

그런데 여기서 z~i\tilde{\mathbb{z}}_{i}는 무엇이고, zi\mathbb{z}_{i}는 언제 사용되는 변수인지 pseudo code를 통해 확인해 보았습니다.

Pseudo code\mathbf{Pseudo \ code}

ZZ는 ensemble predictions, z~\tilde{\mathbb{z}}는 target vector로 학습하기 전에 초기화시킵니다.

  1. 한 epoch이 끝날때마다 ZZ는 아래 수식과 같이 update됩니다. (이때 z\mathbb{z}는 마지막 batch의 output.)

    • ZαZ+(1α)zZ ← \alpha Z + (1-\alpha)\mathbb{z}
  2. 앞서 update한 ZZ로 target vector(z~\tilde{\mathbb{z}})를 update시킵니다. (bias correction)

    • z~Z/(1αt)\tilde{\mathbb{z}} ← Z / (1-\alpha^{t})
  3. Temporal ensembling은 과거 모델의 결과와 현재 모델 결과를 ensemble → prediction vector

그러나 Temporal ensembling의 경우, 각 target은 epoch당 한번 update되므로 학습이 느리다는 단점이 존재합니다.

Bias correction이란?

  • vt=βvt1+(1β)θtv_{t} = \beta v_{t-1} + (1-\beta) \theta_{t}

EMA(exponential moving average) term에서 초기값에 따라 실제 데이터의 값을 잘 따라가지 못하는 경우도 있는데, 실제 데이터와 유사하도록 보정해주는 방법입니다.

예를 들어, 초기값 v0v_{0}을 0이라 가정해 보도록 하겠습니다. (이때 β\beta = 0.98)

v1=0.98×0+0.02θ1=0.02θ1v_{1} = 0.98 \times 0 + 0.02 \theta_{1} = 0.02 \theta_{1}

v2=0.98×v1+0.02θ2=0.0196θ1+0.02θ20v_{2} = 0.98 \times v_{1} + 0.02 \theta_{2} = 0.0196 \theta_{1} + 0.02 \theta_{2} \approx 0

이처럼 v2v_{2}이 0으로 수렴하게 되는데, 이를 보정해주는 작업이 bias correction입니다.

수식은 아래와 같습니다.

  • vt=βvt1+(1β)θtv_{t} = \beta v_{t-1} + (1-\beta) \theta_{t}

위의 수식에 적용을 해보자면, 앞서 구한 v2v_{2}0.0196θ1+0.02θ20.0196 \theta_{1} + 0.02 \theta_{2}이고 1β2=1(0.98)2=0.03961-\beta^{2} = 1-(0.98)^{2} = 0.0396이므로 v2v_{2}는 아래와 같습니다.

  • v20.0396=0.0196θ1+0.02θ20.0396\frac{v_{2}}{0.0396} = \frac{0.0196 \theta_{1} + 0.02 \theta_{2}}{0.0396}

그럼 v_{2}보다 큰 값을 만들 수 있고 초기값 부근에서 값을 보정받을 수 있습니다.

Method

위에서 설명드린 2가지 방법은 각각 한계점을 가지고 있었습니다. 이러한 한계점을 보완하기 위해 저자는 Mean teacher라는 방법을 제안합니다. Mean teacher는 model의 weight에 대해 weighted average하는 방법을 말합니다.

위에서 설명드렸던 2개의 방법과 저자가 제안한 방법의 제일 큰 차이는 teacher model, student model을 사용한다는 점입니다. (위의 2가지 방법은 single network로써, parameter share)

Train Process

  1. Image가 student model, teacher model을 거침.
  2. Student/teacher model에 각각 noise(η\eta, η\eta^{'})를 줌. (dropout)
  3. Student/teacher model의 prediction을 가지고 Loss계산
  4. Student model의 weight(θ\theta)를 이용하여 teacher model의 weight(θ\theta^{'}) update.

Loss

  • Student model과 hard label간의 classification cost(cross-entropy)
  • Teacher model과 student model간의 consistency cost(mean-squared error)
Consistency cost=LSSL=w(t)1BiBf(x,θ,η)f(x,θ,η)2Consistency \ cost =L_{SSL} = w(t) \frac{1}{|B|} \sum_{i \in B} || f(\mathbb{x},\theta^{'},\eta^{'}) - f(\mathbb{x},\theta,\eta)||^{2}

Notation

  • η\eta, η\eta^{'}: sampling noise
  • θ\theta, θ\theta^{'} : teacher/student network의 weight
  • Consistency cost는 student/teacher network의 prediction의 expected distance.
    (저자는 Consistency cost를 mse loss로 define)
  • classification cost + consistency cost를 줄이는데 초점을 둠.

Weight update process

teacher model의 weight는 student model의 weight로부터 update

  • Exponential moving average (EMA)

  • θt=αθt1+(1α)θt\theta^{'}_{t} = \alpha \theta^{'}_{t-1} + (1-\alpha) \theta_{t}

Mean teacher versus Temporal ensembling

  • temporal ensembling은 single network, Mean teacher는 teacher/student network 분리.
  • temporal ensembling과 다르게 Mean teacher는 step마다 teacher network의 weight를 update시켜주기 때문에, teacher/student 간의 빠른 피드백을 주어 접근하는 방식입니다.
  • Mean teacher방법은 large dataset, on-line learning에 적합함.
  1. x^=λxi+(1λ)xi\mathcal{\hat x } = \lambda \mathcal{x_{i} } + (1-\lambda)\mathcal{x_{i} }

  2. y^=λyj+(1λ)yj\mathcal{\hat y } = \lambda \mathcal{y_{j} } + (1-\lambda)\mathcal{y_{j} }

Experiments

본 연구에서는 SVHN, CIFAR-10에 대해 실험을 진행하였습니다.

  • Datasets
    • SVHN train/test : 73257/26032
    • CIFAR-10 train/test : 50000/10000
    • Imagenet 2012
  • Architecture : 13-layer ConvNet, Resnet
  • Regularizations : random translations, horizontal flips, gaussian noise, dropout
  • T : ramp up its weight from 0 to its final value during the first 80 epochs
  • Consistency cost = mean squared error

위 테이블은 SVHN에 대하여 여러 semi-supervised-learning 방법론을 적용했을 때의 error rate를 보여줍니다. Mean Teacher의 경우 label을 조금만 써도 error rate가 낮음을 알 수 있습니다.
(Temporal Ensembling, \amalg model의 경우 label 1,000개 사용했을 때랑 비슷 )

위 테이블은 CIFAR-10, ImageNet에 대하여 Mean teacher를 적용했을 때에 대한 error rate를 보여줍니다. 이때 CIFAR-10의 경우 4,000개의 label, Imagenet의 경우 10%의 label만 사용하였습니다.
State of the art는 VAT(Virtual Adversarial Training)이었고, 앞 실험과 다르게 ConvNet뿐만 아니라 ResNet(12-block, 26-layer)을 Mean teacher와 결합하였습니다.

  • ResNet + Mean Teacher → 좋은 성능을 보임.

Conclusions

  • Semi-supervised learning에서 Temporal Ensembling방법은 큰 강점을 보임
    • 그러나 epoch당 한번 update됨을 단점이 존재
  • Large dataset, on-line learning에 적합한 Mean teacher를 제안
  • Consistency regularization은 teacher-generated targets에 dependent

Reference

  • Temporal Ensembling for Semi-Supervised Learning[2016]
  • Andrew Ng의 Bias correction
profile
김성철

0개의 댓글