[2021 NeurIPS] DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

Hyungseop Lee·2024년 10월 8일
0

Rao, Yongming, et al. "Dynamicvit: Efficient vision transformers with dynamic token sparsification." Advances in neural information processing systems 34 (2021): 13937-13949.


Abstract

  • Vision Transformer의 attentions은 sparse하다.
    우리는 vision transformers에서 final prediction이 a subset of most informative tokens만으로 이루어지며, 이는 정확한 image recognition을 위해 충분하다는 것을 관찰했다.
    이 관찰을 바탕으로, input에 따라 redundant tokens을 점진적으로 prune하는 dynamic token sparsification framework를 제안한다.
    구체적으로, 현재 feature를 기반으로 각 token의 importance score를 추정하기 위한 lightweight prediction module을 design했다.
    이 module은 다양한 layers에 추가되어 redundant token을 계층적으로 prune한다.
    prediction module은 end-to-end 방식으로 optimize하기 위해, 우리는 attention masking strategy를 제안하여 token의 상호작용을 차단함으로써 해당 token을 차별적으로 prune한다.
    self-attention의 특성 덕분에, unstructured sparse tokens도 여전히 HW friendly하여 실제 속도 향상을 쉽게 얻을 수 있다.
    input token의 66%를 계층적으로 pruning함으로써, 우리의 방법은 31%~37%의 FLOPs를 줄이고 throughput을 40% 이상 향상시키는 동시에 다양한 ViTs에서 accuracy 저하는 0.5% 이내로 유지된다.
    dynamic token sparsification framework를 장착한 DynamicViT model은 ImagetNet에서 SOTA CNNs 및 ViTs와 비교하여 경쟁력 있는 complexity/accuracy trade-offs를 달성할 수 있다.
    (코드: https://github.com/raoyongming/DynamicViT)

1. Introduction

  • CNN-type networks의 acceleration의 일반적인 방법 중 하나는 less importance filters를 prune하는 것이다.
    ViT와 그 variants에서 input을 처리하는 방식, 즉 input image를 여러 개의 independent patches로 분할하는 방식은
    acceleration을 위해 sparsity를 도입할 수 있는 또 다른 방법을 제공한다.
    이는 많은 token이 final prediction에 거의 기여하지 않는다는 사실을 고려할 때, input instance의 less importance token을 제거할 수 있음을 의미한다.
    이러한 방법은 the token sequence of variable length을 input으로 받을 수 있는 transformer-like models에서만 가능하며,
    unstructured pruned input이 self-attention module에 영향을 주지 않기 때문에 가능하다.
    반면, CNN에서는 pixel의 일부를 제거해도 unstructured neighborhood 때문에 병렬 처리를 통한 convolution의 가속화가 어렵다.
    CNN의 hierarchical architecture가 다양한 vision tasks에서 model efficiency를 향상시켰기 때문에, 우리는 vision transformers에서 downsampling strategy를 탐구하고자 한다.
    (실험 결과, unstructured sparsification가 structural downsampling보다 ViTs의 성능을 더 향상시킬 수 있음을 보여준다)
    우리의 기본 아이디어는 Figure 1에 나와 있다.

  • 본 연구에서는 dynamic way로 pruning할 token을 결정하기 위해 lightweight prediction module을 사용하는 DynamicViT를 제안한다.
    특히 각 input instance에 대해 prediction module은 어떤 token이 uninformative(정보성이 낮고) and need to abandoned(버려져야 하는지)를 결정하는 customized binary decision mask를 생성한다.
    이 module은 ViT의 여러 layer에 추가되며, 각 prediction module 이후 점진적으로 더 많은 token을 pruning하여 계층적으로 sparsification을 수행한다.
    특정 layer 이후 pruned된 token들은 이후 feed-forward 과정에서 다시 사용되지 않는다.

    이 lightweight module이 도입하는 추가의 computatioanl overhead는 uninformative token을 제거함으로써 절약되는 computational overhead에 비해 매우 작다.

  • preidction module은 ViT backbone과 함께 end-to-end manner로 최적화될 수 있다.
    이를 위해 두 가지 specialized strategies를 채택했다.

    • 첫 번째는 Gumbel-Softmax를 사용하여 distribution에서 sampling하는 non-differentiable problem을 극복하여
      end-to-end training이 가능하게 하는 것
      이다.
    • 두 번째는 학습된 binary decision mask를 사용하여 uninformative tokens을 제거하는 방법에 대한 것이다.
      binary decision mask에서 0의 요소가 instance마다 다르기 때문에 training 중에는 병렬 계산이 불가능해질 수 있으며,
      prediction module이 token을 유지할지 여부에 대한 probability distribution을 계산해야 하므로
      이는 back-propagation을 방해할 수 있다.
      또한 버려진 token은 attention matrix 계산에 영향을 미치기 때문에 zero vectors로 설정하는 것도 좋은 방법은 아니다.
      따라서 우리는 binary decision mask를 기반으로 버려진 token과 다른 모든 token 간의 연결을 차단하는 'attention masking'이라는 전략을 제안한다.
      이를 통해 위에서 설명한 어려움을 극복할 수 있다.
      또한 ViT의 original training objective를 수정하여 특정 layer 이후 pruned된 token의 비율을 제한하는 항을 추가했다.
      inference 단계에서는 더 이상 differentiable을 고려할 필요가 없기 때문에 각 input instance에 대해 특정 layer 이후에 고정된 양의 token을 직접 버릴 수 있으며, 이는 inference 속도를 크게 가속화할 수 있다.
  • 우리는 DeiT [25]와 LV-ViT [16]을 백본으로 사용하여 ImageNet에서 이 방법의 효과를 입증했다.
    우리의 DynamicViT는 transformer-like model의 가속화를 위해 sparsity in space을 보여주며,
    향후 transformer-like models의 가속화에 새로운 길이 열리길 기대한다.

2. Related Work

Vision transformers.

  • ViT는 처음으로 non-overlapping image patches for the image cls taks를 위해 transformer architecture를 바로 적용했던 연구.

  • DeiT는 convolution-free transformer에 대한 많은 training techniques를 제안했다.

  • LV-ViT는 token labeling이라고 불리는 a new training objective를 제안함으로써 성능을 향상시킴

Model acceleration.

  • Model acceleration techniques은 edge devices에서 DL model을 deployment할 때 매우 중요하다.
    Deep model의 inference speed를 가속화하기 위해 사용할 수 있는 다양한 기술이 있으며,
    여기에는 quantization, pruning, low-rank factorization, knowledge distillation 등이 있다.
    Transformer model의 inference speed를 가속화하기 위한 여러 연구도 존재한다.
    예를 들어, TinyBERT는 transformer의 inference를 가속화하기 위해 distillation methdo를 제안했다.
    Star-Transformer는 fc 구조를 start-shaped topology로 대체하여 quadratic space and time complexity를 linear로 줄였다.

  • 우리의 방법은 CNN에서의 neuron이 아닌 less importance tokens을 pruning하여 informative patches을 활용하는 것을 목표로 한다.


3. Dynamic Vision Transformers

3.1 Overview

  • 우리의 DynamicViT의 전체 framework는 Figure 2에 있다.
    DynamicViT는 backbone으로 사용되는 일반적인 vision transformer와 여러 prediction modules로 구성된다.
    backbone network는 ViT [8], DeiT [25], LV-ViT [16]와 같은 다양한 vision transformer로 구현될 수 있다.
    prediction module은 token을 dropping/keeping할 probabilities를 생성하는 역할을 한다.
    token sparsification은 전체 network를 통해 특정 위치에서 계층적으로 수행된다.
    예를 들어, 12-layer transformer가 주어졌을 때, 우리는 4번째, 7번째, 10번째 block 앞에서 token sparsification을 수행할 수 있다.
    Training 중에는, 새로 고안된 masking strategy 덕분에 prediction module과 backbone network를 end-to-end로 최적화할 수 있다.
    Inference 중에는, predefined pruning ratio와 prediction module에서 계산된 score에 따라 most informative tokens만 선택하면 된다.

3.2 Hierarchical Token Sparsification with Prediction Modules

  • DynamicViT의 중요한 특징은 token sparsification을 계층적으로 수행한다는 것이다.
    즉, computation 과정 동안에 uninformative tokens이 점진적으로 버려진다.
    이를 달성하기 위해, 우리는 각 token이 drop될지 keep될지 결정하는 binary decision mask D^{0,1}N\hat{D} \in \{0, 1\}^N를 사용했다.
    여기서 N=HWN=HW은 #patch_embeddings이다. (simplicity를 위해 class token을 생략했다. the decision for class token is alwasys "1")
    우리는 decision mask의 모든 elements를 1로 초기화하고 점진적으로 mask를 update시켰다.
    prediction module은 현재의 decision D^\hat{D}와 tokens xRN×Cx \in \R^{N \times C}를 input으로 갖는다.
    우리는 우선 tokens들을 MLP를 이용하여 project했다.여기서 CC'은 smaller dimension이 될 수 있고, 우리는 C=C/2C' = C/2을 사용했다.
    비슷하게, 우리는 global feature를 다음과 같이 계산할 수 있었다 :
    여기서 Agg()는 모든 존재하는 tokens에 대한 information을 aggregate하는 function이고, average pooling으로 간단히 수행될 수 있다 :The local feature는 특정 token의 정보를 encoding하며, global feature는 whole image의 context를 포함하기 때문에 둘 다 유의미하다.
    따라서 우리는 local and global features를 결합하여 local-global embedding을 얻고,
    이를 또 다른 MLP에 입력하여 token을 drop/keep할지에 대한 probabilityes를 예측한다 :

여기서 πi,0\pi_{i, 0}ii-th token에 대한 probability of dropping을 나타내고, πi,1\pi_{i, 1}은 probability of keeping을 나타낸다.
그러면 우리는 π\pi로부터 sampling함으로써 current deicsion DD을 생성하고, D^\hat{D}를 update할 수 있다.
여기서 \odot는 Hadamard product(element-wise product)이다.
token은 한 번 dropped된다고 정해졌으면, 그 다음에는 절대 쓰이지 않는다.

3.3 End-to-end Optimization with Attention Masking

  • 비록 우리의 목표가 token sparsification을 수행하는 것이지만, training 중에 이를 실제로 구현하는 것은 간단하지 않다는 것을 발견했다.

  • 첫째, binary decision mask DD를 얻기 위해 π\pi로부터 sampling하는 것은 non-differentiable하기 때문에 end-to-end training을 막는다.
    이를 극복하기 위해, 우리는 probabilities π\pi로부터 sampling하기 위해 Gumbel-Softmax technique[15]을 적용했다.
    우리는 DD가 kept(유지된) tokens의 mask를 나타내기 때문에 "1" index를 사용한다.
    Gumbel-Softmax의 output은 one-hot tensor이며, 그 expectation은 정확히 π\pi와 같다.
    동시에 Gumble-Softmax은 differentiable하므로 end-to-end training이 가능하다.

  • 두번째 obstacle은 training 중에 token pruning을 시도할 때 발생한다.
    decision mask D^\hat{D}는 보통 unstructured하며, 서로 다른 sample들의 mask에서는 1의 개수가 다양하게 나타난다.
    따라서 D^i=0\hat{D}_i=0인 token을 단순히 버리면 batch 내 sample 간의 token 수가 비균일해져서 병렬 처리가 어려워진다.
    그래서 우리는 token 수를 유지하면서도 제거된 token과 다른 token 간의 상호작용을 줄여야 한다.
    또한 binary mask D^\hat{D}를 사용해 제거할 token을 단순히 0으로 만드는 것이 바람직하지 않다는 것도 발견했다.
    self-attention matrix를 계산할 때 zerod token들이 Sotmax operation에 여전히 영향을 주기 때문이다.
    이를 위해 우리는 dropped token의 영향을 완전히 없앨 수 있는 'attention masking' 전략을 고안했다.
    구체적으로, 우리는 attention matrix를 다음과 같이 계산한다 :
    Eq (10)을 통해 Gi,j=1G_{i, j}=1인 경우 jj-th token이 ii-th token의 update에 기여함을 나타내는 graph를 구성한다.
    각 token에 대해 명시적으로 self-loop를 추가하여 수치적 안정성을 개선했으며, self-loop는 결과에 영향을 미치지 않는다는 점도 쉽게 증명할 수 있다 :
    만약 D^j=0\hat{D}_j = 0이면, jj-th token은 자신을 제외한 다른 token에 기여하지 않는다.
    Eq (11)은 kept tokens들만을 고려하여 계산된 attention matrix와 동일한 형태를 유지하면서도 N×NN \times N의 고정된 shape을 갖는 masked attention matrix A~\tilde{A}를 계산한다.


3.4 Training and Inference

  • 이제 DynamicViT의 training objectives를 설명한다.
    DynamicViT의 training은 favorable decisions을 생성하기 위해 prediction modules을 학습하는 것
    token sparsification에 적응을 위해 backbone을 fine-tuning하는 것을 포함한다.
    BB개의 samples이 minibatch를 이룰 때, 우리는 standard cross-entropy loss를 채택했다 :
    yy는 softmax를 거친 Dynamic ViT의 prediction이고 yˉ\bar{y}는 ground truth이다.

  • token sparsification으로 인해 성능에 미치는 영향을 최소화하기 위해,
    우리는 original backbone network를 teacher model로 사용하고, DynamicViT의 동작이 teacher model에 최대한 가깝게 되기를 바란다.
    구체적으로, 우리는 이 제약을 두가지 측면에서 고려한다.
    첫째, DynamicViT의 남아있는 최종 token이 teacher model의 token과 가깝게 유지되도록 한다.
    이는 일종의 self-distillation으로 볼 수 있다 :
    (내가 이해한 내용 : pretrained model이 teacher가 되고, pruned model이 student가 됨)
    여기서, tit_itit_i'는 각각 DynamicViT과 teacher model의 마지막 block 이후 ii-th token을 나타낸다.
    D^b,s\hat{D}^{b,s}ss-th sparsification stage에서 bb-th sample에 대한 deicsion mask를 나타낸다.

  • 둘째, 우리는 DynamicViT와 teacher 사이의 predictions 차이를 최소화하기 위해 KL divergence를 사용했다.
    yy'은 teacher model의 prediction이다.

  • 마지막으로, 우리는 ratio of the kept tokens을 predefined value(미리 정해둔 값)으로 제한하고자 한다.
    SS stages에 대한 a set of target ratios를 ρ=[ρ(1),...,ρ(S)]\rho=[\rho^{(1), ..., \rho^{(S)}}]라고 할 때,
    우리는 prediction module을 supervise하기 위해 MSE loss를 사용했다 :


몰랐던 개념

profile
Efficient Deep Learning

0개의 댓글