[Review] Is Normalization Indispensable for Training Deep Neural Networks?

Hwan Heo·2021년 11월 27일
1
post-custom-banner

Normalization 없이 학습을 진행했을 때 performance degradation 을 지양할 수 있는가에 대한 discussion 과 이를 바탕으로 설계한 no-normalization architecture 인 RescaleNet 을 제시하였다.
https://papers.nips.cc/paper/2020/file/9b8619251a19057cff70779273e95aa6-Paper.pdf (NeurIPS2020 Oral)

1. Introduction

이 논문은 batch normalization 이 대부분의 DNN 학습 과정에서 중요한 역할을 하고 있는 최근 DNN 에서,

  1. Normalization 없이 stable 한 학습을 진행할 수 있는가
  2. Normalization 없이 학습을 진행했을 때 performance degradation 을 지양할 수 있는가

에 대한 연구, 특히 질문 (2)에 대한 연구를 진행하여 이를 바탕으로 한 새로운 normalization-free architecture RescaleNet 를 제시하였다.

2. Preliminaries

Residual Network 로 대표되는 supervised feature extractor 의 문제점을 두개로 지적하면서 논문을 서술한다.

  1. Exploding Variance
  2. Dead ReLU

로 대표되는 두가지 문제점인데, 특히 (1) 에 대해서 normalization 을 사용하지 않을 때 성능 저하가 눈에 띈다고 한다.

2.1. Exploding Variance

DNN training 은 몇가지 정해진 Initialization 방법에 따라 weight initialization 한 후 학습을 진행하고, 그 중 ResNet 에 쓰이는 방법은 'He Initialization' 이다.

He Initialization 을 통해, residual block 의 input 과 output 은 그 variance 가 같게 된다. 즉,

Var(Fl(xl1))=Var(xl1)\text{Var} ( \mathcal F_l (x_{l-1} ) ) = \text{Var}(x_{l-1} )

이때 F\mathcal F 는 residual block 을 의미한다.

이때 normalization 이 없다면 다음과 같은 문제가 발생한다.

Var(xl)=Var(xl1+Fl(xl1))Var(xl1)+Var(Fl(xl1))=2Var(Fl(xl1))\begin{aligned} \text{Var}(x_l ) &= \text{Var}(x_{l-1} + \mathcal F_l (x_{l-1}) ) \\ &\simeq \text{Var}(x_{l-1}) + \text{Var}(F_l (x_{l-1})) = 2 \text{Var}(F_l (x_{l-1})) \end{aligned}

input 과 residual block 간의 correlation 이 0이 가깝기 때문이며, 이에 따라 resiual block 을 통과할 때마다 variance 가 이전 block 의 2배가 되어, 종래에는 exploding 하게 될 것이다.

2.2. Dead ReLU

Dead relu 는 initialization 을 통해서 극복할 수 없다고 한다. 한 번 activation 이 0이 된 이후에는 다시 recover 되지 않고, VGG model 등의 대표적인 DNN model 에서 20 layer 에 40% 정도의 neuron 이 deactivate 된다고 실험을 통해서 보이고 있다.

Figure 는 Weight initialization 이후 training process 를 진행하지 않고 activation 여부를 측정한 실험인데, layer 가 깊어질수록 deactivate 되는 뉴런의 수가 급증하는 것을 볼 수 있다. 여기서 all pos, all neg 는 neuron entry 가 모두 같은 부호일 경우 (즉 dead activation) 를 나타내고, non-linear 는 반대이다.

  • VGG architecture 에 대한 figure 만을 제공하고 있는데, Residual architecture 에서 BN 을 대체하는 구조를 제시하는 논문이니만큼 ResNet 을 비롯하여 다른 architecture 에 대해서도 Dead ReLU 비율에 대한 실험이 궁금하다.

3. RescaleNet

3.1. RescaleNet for Residual Learning

exploding/vanishing variance 문제를 해결하기 위해, residual 구조에 multiplier term 을 추가한 다음과 같은 'Rescale' Network 를 제시한다.

xk=αkxk1 + βkFk(xk1)x_k = \alpha_k x_{k-1} \ + \ \beta _k \mathcal F_k (x_{k-1} )

He initialization 을 통해서 Var[F(xk1)]Var[xk1]\text{Var}[\mathcal F(x_{k-1}) ] \simeq \text{Var}[x_{k-1}] 이므로,

Var[xk]=αk2Var[xk1] + βk2Var[Fk(xk1)]=αk2Var[xk1] + βk2Var[xk1]=(αk2+βk2)Var[xk1]\begin{aligned} \text{Var}[x_k ] &= \alpha_k^2 \text{Var}[ x_{k-1} ]\ + \ \beta _k^2 \text{Var} [\mathcal F_k (x_{k-1} )] \\ &= \alpha_k^2 \text{Var}[ x_{k-1} ]\ + \ \beta _k^2 \text{Var} [\mathcal x_{k-1} ] \\ &= (\alpha_k^2 + \beta _k^2)\text{Var} [\mathcal x_{k-1} ] \end{aligned}

임을 알 수 있다. 이때 stable variance 를 유지하기 위하여, αk2+βk2=1\alpha_k^2 + \beta _k^2 = 1 이라는 constraint 를 도입***하여 stable 한 variance 를 유지할 수 있다.

이제 Rescaled-Residual connection 을 전개하여, 우리는 다음과 같은 식을 얻을 수 있다.

xL=(i=1Lαi)x0 + k=1Lβki=k+1LαiFk(xk1)x_L = ( \prod _{i=1} ^ L \alpha _i ) x_0 \ + \ \sum_{k=1} ^{L} \beta_k \prod_{i=k+1}^L \alpha _i \mathcal F_k (x_{k-1} )

이때 optimal coefficient 는 각기 다른 residual block 에 동등한 weight 를 부과하도록 할 것이다. 즉, 다음과 같다.

kkk, βki=k+1Lαi = βki=k+1Lαi\forall _k k\neq k', \ \beta_k \prod_{i=k+1}^L \alpha _i \ = \ \beta_{k'} \prod_{i=k'+1}^L \alpha _i

일반성을 잃지 않고 k=k+1k' = k+1 일 때 식을 전개해보면,

βki=k+1Lαi = βk+1i=k+1+1Lαiβkαk+1=βk+1\beta_k \prod_{i=k+1}^L \alpha _i \ = \ \beta_{k+1} \prod_{i=k+1+1}^L \alpha _i \\ {} \\ \rightarrow \beta_k \alpha _{k+1} = \beta_{k+1}

이고, , αk2+βk2=1\alpha_k^2 + \beta _k^2 = 1 이라는 제약조건에 따라

1βk+121βk2=1{1 \over \beta_{k+1}^2} - {1 \over \beta_{k}^2} = 1

임을 알 수 있다.

3.2. Learable Multiplier

위 제약조건들을 통해서 Rescaled Residual Connection 을 다음과 같이 정의할 수 있다.

xk=k1+ck+cxk1 + 1k+cFk(xk1)x_k = \sqrt{ {k-1+c} \over k+c } x_{k-1} \ + \ \sqrt{ {1} \over k+c } \mathcal F_k (x_{k-1} )

여기서 residual block 에 대한 weight 값이 클수록 깊은 layer 위주로 학습이 이루어지고, 작을수록 얕은 layer 위주로 학습이 이루어질 것임을 알 수 있다.

따라서 저자들은 learnable multiplier mkm_k를 도입한 다음과 같은 alternative objective 를 최종적인 Rescaled Residual Connection 으로 제시한다.

xk=k1+ck+cxk1 + mkcFk(xk1)x_k = \sqrt{ {k-1+c} \over k+c } x_{k-1} \ + \ { {m_k} \over \sqrt{ c} } \mathcal F_k (x_{k-1} )
  • 이 learnable multiplier 는 자동적으로 학습 초기 단계에는 낮게 설정되어 shallow network 에서의 feature extracting 품질을 높이고, 학습 말기 단계로 갈수록 크게 설정되어 deep layer 에서의 세부적인 NN 학습을 야기한다.

  • 만약 이 mkm_k 값이 학습이 진행됨에 따라 매우 커진다면, stable variance 을 유지할 수 없게 된다.

  • 저자들은 이를 convolution weights 가 training 과정에서 shrink 하고 (weight decay 때문) 이러한 weight shrinking 이 increasing learnable multiplier 와 상쇄되기 때문에 전체 network 의 성능에 큰 영향을 미치지 않는다고 주장하였다.
    (이에 관련한 디테일한 실험이 있었으면 좋을 것 같다)

3.3. Bias Initialization

Bias initialization 에 대해서도 적절한 setting 이 필요한데, 저자들은 이를 우선 matrix multiplication 전에 더하는 것으로 설정하였다. 즉,

y=W(x+b)y = W(x+b)

또한 이를 data-dependent 하게 (Negative of the mean of the 1st mini-batch) initialization 했다고 한다.

  • 저자들은 Bias before matrix-multiplication 이 optimize 하기 쉽고, data-dependent init 은 dead ReLU problem 에 대한 효과적인 대응책이 될 수 있다고 주장한다.

  • 이는 mean of the 1st mini-batch 이 모든 neuron 에 대해서 substracted 될 것이므로, 마치 BN 에서 평균을 0으로 맞춰주는 것과 비슷한 효과가 일어나는 것이 아닌가 생각한다.

4. Experiments

4.1. Image Classification



4.2. Object Detection & Segmentation

4.3. Video Classification & Machine Translation

profile
기타치는AI Researcher
post-custom-banner

0개의 댓글