[2024 NeurIPS] Adaptive Depth Networks with Skippable Sub-Paths

Hyungseop Lee·2024년 3월 10일
0

Paper Info

https://proceedings.neurips.cc/paper_files/paper/2024/file/3a2d96d2eb2902043c2db705ca03e9a2-Paper-Conference.pdf


Abstract

(이전의 depth adaptive network에 대한 문제점 지적)

  • network depth의 체계적인 adaption은 inference latency를 효과적으로 제어하고
    다양한 devices의 resource 조건을 충족시키는 효과적인 방법이 될 수 있음.
  • 하지만 이전의 depth adaptive network는 왜, 어떤 layer를 건너 뛸 수 있는지에 대한 general principle과 formal explanation을 제공하지 않음.
    따라서 그들의 approach는 generalization되기 어렵고, long and complex training steps이 필요함.

(이전의 depth adaptive network의 아쉬운 점을 개선한 연구를 간략하게 소개)

  • 이 논문에서,
    우리는 with minimal training effort로 다양한 networks에 적용 가능한
    adpative depth networks에 대한 pratical approach를 제시할 것
    임.
  • 우리의 approach에서,
    모든 hierarchical residual stage는 두 개의 sub-paths로 나뉨.
    • the first sub-path는 hierarchical feature learning을 위해 의무적으로 필요함
    • the second sub-path는 skipped되어도 기능을 개선하고, degradation을 최소화하기 위해 optimized된다.
  • 이전의 adaptive network와는 달리,
    우리의 approach는 각 sub-network를 반복적으로 학습하지 않아도 되므로
    훨씬 짧은 training time이 소요
    된다.
  • device에 deploy된 후에는,
    즉시 다양한 depth의 sub-networks를 선택하여
    single model에 대한 다양한 accuracy-efficiency trade-off를 제공할 수 있다.
  • 우리는 제안된 training method가 전체 prediction error를 줄이는 동시에
    선택된 sub-paths를 skipping할 때의 영향을 최소화하는 공식적인 근거를 제시할 것이다.
  • 또한 CNNs과 vision transformers에 대한 generality와 effectiveness도 증명할 것임.

1. Introduction

(관련 연구 1, 문제점 지적)

  • CNNs 과 transformer와 같은 Modern deep neural networks는
    high computational costs로 state-of-the-art performance를 제공한다.
    그래서 그러한 inference capabilities를 다양한 resource constrained device에 적용할 수 있는 연구들이 진행되어져 왔다.
    예를 들자면,
    • compact architecture
    • network pruning
    • weight/activation quantization
    • knowledge distillation
    • ...

      하지만,
      이러한 approach들은 static accuracy-efficiency trade-offs를 제공하므로,
      모든 종류의 resource-constraints가 있는 device에 one single model을 배포하는 것은 불가능하다.

(관련 연구 2, 문제점 지적)

  • 이전에는 neural network의 adaptability를 예측 가능하게 하기 위해
    network의 depths, widths 또는 둘 다의 redundancy를 활용한 시도가 있었다.
    그러나 이전 adaptive network의 주요 difficulty 중 하나는
    hard to train and require significantly longer training time than non-adpative networks.

(이 논문의 아이디어 제시)

  • 우리는 CNNs과 transformers와 같은
    다양한 network에 적용 가능한 adaptive depth networks의 training method를 소개.
    제안된 adaptive depth networks에서는,
    every residual stage는 2개의 sub-paths로 나뉘어 지고,
    그들은 서로 다른 properties로 train
    되어진다.

    • the first sub-paths는 hierarchical feature learning을 위해 필수적으로 있어야 함.
    • the second sub-paths는 skip해도 최소한의 performance degradation을 유발하지 않도록 optimized된다.
      더 자세히 말하면,
      모든 residual stage의 second sub-path는
      skip되어질 때의 performance degradation을 최소화하기 위해
      the first sub-paths(= previous mandatory sub-paths)에서의
      feature distribution을 보존하기 위해 optimized된다.
      (내 생각 : second sub-paths의 역할이 중요한 듯함)
    • training 동안에,
      second sub-paths의 이러한 속성이 skip-aware self-distillation
      (skip 의식적인 자가 증류)를 통해 이루어
      지고,
      이는 Figure 1-(a)에 표시된 것처럼 only one smallest sub-network,
      즉 base-net만 jointly trained되어진다.
      (self distillation에 대한 짧은 이해 : https://www.youtube.com/watch?v=x74pey3rv14)

      • self distillation : model size를 줄이기 위해 model이 자기 자신으로부터 지식을 증류(추출)하여 학습하는 기술.
    • skip-aware self-distillation은
      개별 sub-networks의 명시적인 training이 필요하지 않아
      이전 adaptive network보다 훨씬 짧은 training time을 갖게 된다.
      하지만, 한 번 trained되어진 후에는
      device의 resource 조건을 충족하기 위해 single model로부터 즉시 다양한 depth의 subnet를 선택
      할 수 있다. (Figure 1-(b) 참조)

    • 게다가,
      Figure 1-(c)에서 볼 수 있듯이,
      이러한 다양한 depths의 sub-networks들은 regularization effect로 인해
      개별적으로 trained된 non-adaptive networks보다 우수한 성능을 보인다.

  • Section 3에서,
    더 자세한 architectural parttern and training algorithm을 소개할 것이고,
    skip-aware self-distillation을 사용하여 선택된 sub-paths가
    input features의 level을 최소한으로 변경하면서
    prediction error를 줄이도록 optimized되었음을 formally하게 보여줄 것이다.

  • Section 4에서,
    우리는 CNN 및 vision transformers에서 우리의 adaptive depth networks가
    대응하는 개별 network를 능가하며 실제 inference acceleration and energy-saving을 달성한다는 것을 실험적으로 증명할 것이다.


2. Related Work

Adaptive Networks :

(skip)

Residual Block with Shortcuts :

(skip)


3. Adpative Depth Networks

  • 먼저 adpative depth networks의 architecture와 training details를 소개하겠다.
    그리고나서, depth adaptation이 최소한의 performance degradation과 training effort로 어떻게 달성될 수 있는지에 대한 이론적 근거를 discuss할 것임.

3.1. Architectural Pattern for Depth Adaptation

  • 일반적으로 ResNets과 Swin transformers와 같은 hierarchical residual networks에서,
    ss-th residual stage는 LL개의 동일한 residual blocks로 이루어져 있다.
    그리고 그 ss-th residual stage는 다음과 같이
    input features h1sh_1^s를 transform하여 output feature hsh^s를 생성한다.

  • residual function Fl(l=1,...,L)F_l(l=1, ..., L)은 traditional compositional networks와 마찬가지로
    hierarchical features를 학습한다.
    이전 문헌(Jastrzebski et al., 2018; Greff et al., 2016)에 따르면, 
    residual function은 같은 feature level에 있는 이미 학습된 features들을
    세밀하게 조정하는 기능을 학습하는 경향이 있다
    고 한다.

    만약 residual block이 대부분의 경우 input features의 수준을 변경하지 않으면서도
    feature를 세밀하게 조정한다면,
    test time에 block을 drop해도 residual network의 성능에는 큰 영향을 미치지 않는다.

    그러나 일반적인 residual networks에서는 대부분의 residual block이
    새로운 level의 feature를 학습하면서도 feature를 세밀하게 조정하기 때문에,
    test time에 임의로 residual block을 삭제하면 performance가 매우 크게 떨어
    진다.

    따라서 우리는 몇개의 selected residual blocks이
    feature refinement에 더 집중하도록 명시적으로 training 중에 격려된다면,
    이러한 block들은 test time에 computation을 절약하기 위해 skip할 수 있으며,
    prediction accuracy의 손실도 미미해질 것이라는 가설을 세웠다.

  • 이를 위해,
    우리는 adaptive depth networks를 위한 architectural pattern을 제안한다.
    이 pattern에서 각 residual stage는 Equation 1과 Figure 2.에 나와 있는 것처럼
    두 개의 연속적인 sub-paths 또는 FbasesF^s_{base}FskippablesF_{skippable}^s로 나뉘어진다.


3.2. Skip-Aware Self-Distillation

Algorithm 1은 skip-aware self-distillation라고 불리는
우리의 training method를 보여주는데,
여기서 Equation 2가 loss function에 포함되어 있으며,
MM의 largest sub-networks(=super-net)와 smallest sub-networks(=base-net)가
jointly(함께) trained되어진다.
Algorithm 1에서,
제안된 adaptive depth networks인 MM은 'skip'이라는 추가의 argument를 받는데,
이는 그들의 skippable sub-paths가 skipped되어지는 residual stage를 제어한다.
예를 들어,
MM에 4개의 residual stages가 있다면,
'skip=[True, True, True, True]'을 전달하여 base-net가 선택되어진다.
(내 이해가 맞다면, 'skip=[False, False, False, False]'를 전달하면 super-net이 되는 것)

* MM을 training하기 위해서 2개의 sub-networks만이 관련되며,
총 훈련 시간은 두 개의 sub networks를 개별적으로 훈련하는 것보다 크지 않음.
test time에는 residual stage에서 sub-paths를 systemical하게 skipping하여
다양한 depth의 sub-networks를 실시간으로 선택할 수 있다.
예를 들어, 4개의 residual stage를 가진 network의 경우,
skip argument를 통해 24=162^4 = 16개의 parameter-sharing sub-networks를 만들 수 있다.
(True or False 경우의 수가 4개의 residual stage에 있으니)


3.3. Formal Analysis of Skippable Sub-Paths

Formal Analysis

  • FskippablesF^s_{skippable}의 residual block들이 identity function을 학습한다면,
    lossbaseloss_{base}에 있는 DKL(hsupershbases)D_{KL}(h^s_{super} || h^s_{base})은 너무 쉽게 minimized될 수 있다.

    하지만, super-net이 losssuperloss_{super}과 함께 jointly train되기 때문에,
    FskippablesF^s_{skippable}은 간단하게 identity function이 될 수 없다.
    이를 Taylor expansion([43])을 통해 살펴볼 수 있다.
    • Taylor expansion에 대한 이해를 도운 참고자료 :
      간단히 설명하자면, 어떠한 함수 f(x)f(x)의 점 aa 주변에 대한 함수를 근사화하고 싶을 때,
      다항식을 이용하여 근사할 수 있다.
      예를 들어, f(x)f(x)를 점 a에서 전개하면 다음과 같이 근사할 수 있다.
  • 우리의 adaptive depth network에 대해서,
    super-net을 training하기 위해 사용되는 loss function LL
    Taylor expansion에 의해 다음과 같이 근사할 수 있다.
    • Taylor expansion 일반화식
    • loss function LL for training super-net approximated with Taylor expansion :

      Eq 4.에서 first order term만 남고,
      나머지 2계도함수 이상의 high order terms들은 O(FL(hLs))O(F_L(h^s_L))로 묶음.

      만약 FL(hLs)F_L(h_L^s) 가 작은 값이라면, O(FL(hLs))O(F_L(h^s_L))의 high-order term들은 무시될 수 있다.
      일반적인 residual networks에서는 각 layer가 새로운 feature를 학습하도록 훈련되며, 이에 대한 제약이 없기 때문에 작은 값을 갖는다는 보장이 없다.
      그런데 adaptive depth network에서, FskippablesF^s_{skippable}의 residual들은
      self-distillation strategy에 의해 작읍 값을 갖도록 강요된다. (Figure 3에 empirical evidence)
      그렇기 때문에,
      매우 작은 값을 갖는 O(FL(hLs))O(F_L(h^s_L)) term은 approximation에서 무시될 수 있다.
      따라서 다음과 같은 approximation을 얻는다 :
      Eq5.에서,
      training 동안에 loss L(hsupers)L(h^s_{super})를 minimizing하기 위해서
      Fj(hjs)F_j(h^s_j)(L(hjs))(hjs)\frac{ \partial(L(h_j^s)) }{ \partial(h^s_j) }의 negative half space로 유도(Eq.6)하여
      Fj(hjs)F_j(h^s_j)(L(hjs))(hjs)\frac{ \partial(L(h_j^s)) }{ \partial(h^s_j) }의 dot product를 minimize한다.
      이는 FskippablesF_{skippable}^s의 모든 residual function들이
      gradient descent와 유사한 효과를 가진 함수를 학습하도록 optimize된다는 것을 의미한다.다시 말해,
      skippable sub-paths의 residual function들은 hbasesh_{base}^s의 feature distribution을 유지하면서
      inference 도중 반복적으로 loss L(hbases)L(h_{base}^s)를 줄인다.
      이 결과를 고려하면,
      우리의 architecture pattern과 self-distillation strategy를 통해
      FskippablesF_{skippable}^s의 layer는 hbasesh_{base}^s의 distribution을 최소한으로 변경하면서
      더 나은 inference accuracy를 위해 input feature hbasesh_{base}^s를 반복적으로 refine하는 함수를 학습한다고 추측
      할 수 있다.

3.4. Skip-Aware Batch Normalization

  • 원래,
    BN은 non-adaptive networks에서 features를 normalizing하여
    training 동안에 internal covariate shift를 다루기 위해 제안되었다.

  • 하지만 our adaptive depth networks에서는,
    inference 중에 서로 다른 sub-networks가 선택될 경우,
    mandatory sub-paths에서 internal covariate shift가 발생
    할 수 있다.
    potential internal covariate shifts를 처리하기 위해서,
    mandatory sub-paths에는 skip-aware BNs이라고 불리는
    switchable BN operators가 사용
    된다.

    • 예를 들어,
      각 residual stage마다 mandatory sub-path에는 2개의 BN set가 있으며,
      그것들은 skippable sub-path가 skip되는지 여부에 따라 전환된다.
  • 이전의 adaptive network는 NN parameter-sharing sub-network를 지원하기 위해
    모든 layer에 대한 NN sets of switchable BNs이 필요했다.
    our adaptive depth network는 지원되는 sub-network의 수에 관계 없이
    2개의 switchable BNs만 있으면 된다.

  • Transformer에서는 BN 대신에 Layer Normalization(LN)을 사용함.
    그래서 switchable BNs 대신에 switchable LN operator를 적용.


4. Experiments

4.1.

profile
Efficient Deep Learning

0개의 댓글