Rao, Yongming, et al. "Dynamicvit: Efficient vision transformers with dynamic token sparsification." Advances in neural information processing systems 34 (2021): 13937-13949.
lightweight prediction module을 design
했다.CNN-type networks의 acceleration의 일반적인 방법 중 하나는 less importance filters를 prune하는 것이다.
ViT와 그 variants에서 input을 처리하는 방식, 즉 input image를 여러 개의 independent patches로 분할하는 방식은
acceleration을 위해 sparsity를 도입할 수 있는 또 다른 방법을 제공한다.
이는 많은 token이 final prediction에 거의 기여하지 않는다는 사실을 고려할 때, input instance의 less importance token을 제거할 수 있음을 의미한다.
이러한 방법은 the token sequence of variable length을 input으로 받을 수 있는 transformer-like models에서만 가능하며,
unstructured pruned input이 self-attention module에 영향을 주지 않기 때문에 가능하다.
반면, CNN에서는 pixel의 일부를 제거해도 unstructured neighborhood 때문에 병렬 처리를 통한 convolution의 가속화가 어렵다.
CNN의 hierarchical architecture가 다양한 vision tasks에서 model efficiency를 향상시켰기 때문에, 우리는 vision transformers에서 downsampling strategy를 탐구하고자 한다.
(실험 결과, unstructured sparsification가 structural downsampling보다 ViTs의 성능을 더 향상시킬 수 있음을 보여준다)
우리의 기본 아이디어는 Figure 1에 나와 있다.
본 연구에서는 dynamic way로 pruning할 token을 결정하기 위해 lightweight prediction module을 사용하는 DynamicViT를 제안한다.
특히 각 input instance에 대해 prediction module은 어떤 token이 uninformative(정보성이 낮고) and need to abandoned(버려져야 하는지)를 결정하는 customized binary decision mask를 생성한다.
이 module은 ViT의 여러 layer에 추가되며, 각 prediction module 이후 점진적으로 더 많은 token을 pruning하여 계층적으로 sparsification을 수행한다.
특정 layer 이후 pruned된 token들은 이후 feed-forward 과정에서 다시 사용되지 않는다.
이 lightweight module이 도입하는 추가의 computatioanl overhead는 uninformative token을 제거함으로써 절약되는 computational overhead에 비해 매우 작다.
이 preidction module은 ViT backbone과 함께 end-to-end manner로 최적화될 수 있다.
이를 위해 두 가지 specialized strategies를 채택했다.
ViT는 처음으로 non-overlapping image patches for the image cls taks를 위해 transformer architecture를 바로 적용했던 연구.
DeiT는 convolution-free transformer에 대한 많은 training techniques를 제안했다.
LV-ViT는 token labeling이라고 불리는 a new training objective를 제안함으로써 성능을 향상시킴
Model acceleration techniques은 edge devices에서 DL model을 deployment할 때 매우 중요하다.
Deep model의 inference speed를 가속화하기 위해 사용할 수 있는 다양한 기술이 있으며,
여기에는 quantization, pruning, low-rank factorization, knowledge distillation 등이 있다.
Transformer model의 inference speed를 가속화하기 위한 여러 연구도 존재한다.
예를 들어, TinyBERT는 transformer의 inference를 가속화하기 위해 distillation methdo를 제안했다.
Star-Transformer는 fc 구조를 start-shaped topology로 대체하여 quadratic space and time complexity를 linear로 줄였다.
우리의 방법은 CNN에서의 neuron이 아닌 less importance tokens을 pruning하여 informative patches을 활용하는 것을 목표로 한다.
여기서 는 -th token에 대한 probability of dropping을 나타내고, 은 probability of keeping을 나타낸다.
그러면 우리는 로부터 sampling함으로써 current deicsion 을 생성하고, 를 update할 수 있다.
여기서 는 Hadamard product(element-wise product)이다.
token은 한 번 dropped된다고 정해졌으면, 그 다음에는 절대 쓰이지 않는다.
비록 우리의 목표가 token sparsification을 수행하는 것이지만, training 중에 이를 실제로 구현하는 것은 간단하지 않다는 것을 발견했다.
첫째, binary decision mask 를 얻기 위해 로부터 sampling하는 것은 non-differentiable하기 때문에 end-to-end training을 막는다.
이를 극복하기 위해, 우리는 probabilities 로부터 sampling하기 위해 Gumbel-Softmax technique[15]을 적용했다.
우리는 가 kept(유지된) tokens의 mask를 나타내기 때문에 "1" index를 사용한다.
Gumbel-Softmax의 output은 one-hot tensor이며, 그 expectation은 정확히 와 같다.
동시에 Gumble-Softmax은 differentiable하므로 end-to-end training이 가능하다.
두번째 obstacle은 training 중에 token pruning을 시도할 때 발생한다.
decision mask 는 보통 unstructured하며, 서로 다른 sample들의 mask에서는 1의 개수가 다양하게 나타난다.
따라서 인 token을 단순히 버리면 batch 내 sample 간의 token 수가 비균일해져서 병렬 처리가 어려워진다.
그래서 우리는 token 수를 유지하면서도 제거된 token과 다른 token 간의 상호작용을 줄여야 한다.
또한 binary mask 를 사용해 제거할 token을 단순히 0으로 만드는 것이 바람직하지 않다는 것도 발견했다.
self-attention matrix를 계산할 때 zerod token들이 Sotmax operation에 여전히 영향을 주기 때문이다.
이를 위해 우리는 dropped token의 영향을 완전히 없앨 수 있는 'attention masking' 전략을 고안했다.
구체적으로, 우리는 attention matrix를 다음과 같이 계산한다 :
Eq (10)을 통해 인 경우 -th token이 -th token의 update에 기여함을 나타내는 graph를 구성한다.
각 token에 대해 명시적으로 self-loop를 추가하여 수치적 안정성을 개선했으며, self-loop는 결과에 영향을 미치지 않는다는 점도 쉽게 증명할 수 있다 :
만약 이면, -th token은 자신을 제외한 다른 token에 기여하지 않는다.
Eq (11)은 kept tokens들만을 고려하여 계산된 attention matrix와 동일한 형태를 유지하면서도 의 고정된 shape을 갖는 masked attention matrix 를 계산한다.
이제 DynamicViT의 training objectives를 설명한다.
DynamicViT의 training은 favorable decisions을 생성하기 위해 prediction modules을 학습하는 것과
token sparsification에 적응을 위해 backbone을 fine-tuning하는 것을 포함한다.
개의 samples이 minibatch를 이룰 때, 우리는 standard cross-entropy loss를 채택했다 :
는 softmax를 거친 Dynamic ViT의 prediction이고 는 ground truth이다.
token sparsification으로 인해 성능에 미치는 영향을 최소화하기 위해,
우리는 original backbone network를 teacher model로 사용하고, DynamicViT의 동작이 teacher model에 최대한 가깝게 되기를 바란다.
구체적으로, 우리는 이 제약을 두가지 측면에서 고려한다.
첫째, DynamicViT의 남아있는 최종 token이 teacher model의 token과 가깝게 유지되도록 한다.
이는 일종의 self-distillation으로 볼 수 있다 :
(내가 이해한 내용 : pretrained model이 teacher가 되고, pruned model이 student가 됨)
여기서, 와 는 각각 DynamicViT과 teacher model의 마지막 block 이후 -th token을 나타낸다.
는 -th sparsification stage에서 -th sample에 대한 deicsion mask를 나타낸다.
둘째, 우리는 DynamicViT와 teacher 사이의 predictions 차이를 최소화하기 위해 KL divergence를 사용했다.
은 teacher model의 prediction이다.
마지막으로, 우리는 ratio of the kept tokens을 predefined value(미리 정해둔 값)으로 제한하고자 한다.
stages에 대한 a set of target ratios를 라고 할 때,
우리는 prediction module을 supervise하기 위해 MSE loss를 사용했다 :