ViTs는 computer vision에서 인상적인 성능을 보여줬으나,
quadratic in the nubmer of tokens으로 high computational cost로 인해 computation-constrained applications에 적용하기에 제한되었다.
그러나 이 많은 수의 token이 모두 필요한 것은 아니다.
모든 token이 똑같이 중요한 것은 아니기 때문이다.
본 논문에서는 image classification task에서의 이전 연구를 확장하여
object detection과 instance segmentation을 위한 inference acceleration을 위해 token pruning을 연구한다.
extensive experiments를 통해 다음 네 가지 insights를 제공한다.
ViTs가 처음 소개되고, 많은 vision tasks에서 빠르게 자리잡으며 활발하게 사용되고 있다.
pair-wise token attention을 통한 global reasoning을 수행하기 위한 unique ability를 가지고 있으며, 이것은 장점이자 단점이 될 수 있다.
이는 architectures의 representational power를 향상시키면서 computational foot-print의 상당한 증가를 초래한다.
이는 resource-constrained settings에 ViTs 적용을 제한한다.
큰 계산 비용을 줄이기 위한 실행 가능한(viable) 전략 중 하나는 input space에서 less critical features를 pruning하여 input-aware inference를 수행하는 것이다.
이 전략은 이전에 CNN에 적용되어 FLOP 측정을 개선했지만, convolution 연산의 the intrinsic regularity(고유한 규칙성) 때문에 HW에서 눈에 띄는 speedup을 얻기는 어려웠다.
하지만 ViT의 등장으로 input-space pruning 방식이 가능해졌다.
ViT의 MLP는 pointwise로 작동하고, self-attention은 본질적으로 임의의 수의 token을 처리할 수 있기 때문에 token을 제거하는 방식이 HW 수정 없이도 상당한 speedup을 달성할 수 있다.
초기 연구에서는 gating network를 활용해 less significant token을 식별하거나 class token으로부터 minimal attention을 받는 token을 제거하는 방법이 제안되었다.
이러한 방법들은 효과적이었으나, image classification에만 적용되었으며, object detection과 instance segmentation과 같은 다른 task에 적용된 사례는 거의 없었다.
dense tasks에서의 token pruning에 대한 연구는 여전히 부족한 상태이다.
이 논문에서는,
ViTs에서 token pruning을 활용한 object detection and instance segmentation을 다루며,
기존 image classification과 dense tasks 간의 거리를 좁히기 위한 목적이다.
우리의 preliminary 실험에서, 기존의 방법을 dense tasks에 적용하는 것이 상당한 성능 저하를 이끈다는 것을 발견했다.
광범위한 실험을 통해 성능을 개선하고 model design을 간소화하는 데 유용한 네 가지 key insights를 도출했다.
이 insights를 기반으로 기존 token pruning 방법보다 훨씬 우수한 성능을 보여주는 방법을 개발했다.
Our insights are as follows:
우리는 이러한 design choices들을 평가하고 이를 기반으로 token을 선택적으로 pruning하는 straightforward model을 만들었다.
이를 SViT라고 부른다.
이 model은 기존의 SOTA token pruning models보다 성능이 뛰어나며, box와 mask에서 mAP 손실을 약 ~1.5에서 약 ~0.3으로 줄이고,
dense counterpart의 inference speed를 전체 network에서 최대 34%, backbone에서 46%까지 가속화했다.
Input image의 모든 영역이 동일하게 중요한 것은 아니므로, redundant areas를 제거함으로써 연산을 줄일 수 있으며,
명백한 accuracy loss 없이도 이를 수행할 수 있다.
Spatially ACT는 CNN에서 pixel을 제거하는 방법을 제안한다.
ViT에 대한 다양한 token pruning 기법들이 classification에서 개발되었으며,
여기에는 gating networks, attention scores, reinforcement learning 및 기타 방법 등이 포함된다.
이 중 ToMe[2]는 token을 제거하는 대신 merge하는 방식을 제안한다.
몇몇 연구들은 dense tasks도 고려했다.
SparseViT는 pyramid transformers에 대해 coarse window를 제거하는 반면, isotropic(등방성) transformers에 대해 finer-grained(더 세밀한) token을 제거한다.
SparseDETR은 DETR architecture의 효율성을 개선하는 데 중점을 두는 반면,
우리는 transformer-based backbone의 개선에 집중한다.
STViT-R은 token을 반복적으로 몇 개의 semantic tokens으로 clustering한 뒤 spatial resolution을 복원하는 방식이지만,
우리는 spatial resolution을 유지하며 자세한 position information을 보존한다.
classification과 dense prediction tasks 간의 주요 차이점 중 하나는 pruned tokens을 어떻게 처리하느냐에 있다.
classification에서는 token pruning 방법이 종종 token을 영구적으로 제거한다.
왜냐하면 pruned token이 더 이상 결과에 영향을 미치지 않기 때문이다.
이는 classification이 항상 유지되는 class token에만 의존하기 때문이다.
그러나 dense prediction tasks에서는 pruned token이 backbone에서 더 이상 update되지 않더라도 이후의 detection head에서 여전히 활용될 수 있다.
따라서 pruned token의 이미 계산된 features를 나중에 사용할 수 있도록 유지하는 것이 유리하다.
pruned token을 보존하지 않을 때는 남아있는 token을 원래 위치에 배치하고, pruned token은 zero-pad하여 dense feature map을 복원한다. [61]
반면, pruned token을 보존하는 경우에는 updated token을 대체하면서 pruned token을 그대로 유지하여 점진적으로 feature map을 구축한다.
삭제된 token을 보존하는 것은 제거하는 것만큼 빠르며(Table 2 참조), 다양한 model에서 dense tasks의 성능을 향상시킨다.
pruned token이 feature maps에서 보존되면, 이를 다시 사용해야 하는지 자연스럽게 고려하게 된다.
이 논문에서 "token preserving"은 pruned tokens을 detection head만 사용하는 것을 의미하는 반면,
"token reactivation"은 이러한 token을 backbone의 이후 layer에서 다시 사용할 수 있도록 하는 것을 의미한다.
token reactivation에 대한 반론은 ViT가 가능한 한 많은 computing resource를 유용한 token에 우선적으로 할당해야 한다고 주장할 수 있다. [60]
따라서 pruned token을 reactive하는 것은 이 원칙을 악화시킬 수 있다.
그러나 "informative"의 정의는 각 layer마다 다를 수 있다.
ViT는 각 layer에서 서로 다른 regions에 집중할 수 있기 때문이다. (see supplementary material)
따라서 pruend token을 reactive하는 기능은 ViT가 각 layer에서 서로 다른 region에 집중할 수 있게 하여,
model이 현재 필요한 token에 집중한 후, 다음 block에서 다시 관련된 token으로 돌아갈 수 있게 한다.
또한, 이는 잘못 pruned된 token이 다시 reactivating될 기회를 제공함으로써 token pruning을 더 견고하게 만든다.
이러한 이점은 동일한 block 당 token 사용량 내에서 token의 전체적인 활용을 더욱 효과적으로 만든다.
Section 4에서 우리는 model이 스스로 pruned tokens을 reuse할지 여부와 시기를 학습할 수 있도록 하고,
이 기능이 model의 accuracy를 box AP에서 0.4, mask AP에서 0.3만큼 향상시킬 수 있음을 보여준다.