BEIT v3, v1을 읽고 읽었지만, 3개 중 가장 어려운 논문인거 같다.
v1 리뷰 / v3 리뷰
Knowledge Distillation이라는 처음보는 개념도 등장하는데,
일단 이 논문부터 하나씩 봐보자
앞선 BEIT v1 논문에서 vision transformer를 학습시킬 때 사용했던 Masked Image Modeling(MIM)
을 이번에도 사용할 건데, v1보다 좀더 high-level semantic을 잘 학습하도록 업그레이드 한 visual tokenizer를 제시한다.
Vector-quantized knowledge distillation
을 활용할 것이다.
patch aggregation strategy
로 discrete한 image patch들로부터 global semantic representation을 잘 보도록 곁들인다.
기본은 BEIT v1과 같이 MIM
방식으로 학습 시키고, image patch
와 visual token
두가지를 학습 재료로 쓸 것이다.
복습: 원본 이미지를 자르고, 토큰화 하여 visual tokens를 만들고 토큰들 중 일부를 mask로 가려 image patches로부터 가린 visual token을 맞추는 방식으로 학습한다.
MIM 학습 방식은 크게 복원 대상을 기준으로 3가지로 나눌 수 있는데, 1. 픽셀값 자체(low-level image elements), 2. hand-crafted features, 3. visual tokens(-BEIT).
하지만 이 세가지 모두 직/간접적으로 픽셀값을 맞추는 거랑 크게 다르지 않다. 하지만 NLP 학습을 보면 high-level semantic으로 학습을 하고 있기 때문에, 거기에서 발전 가능성을 보고 출발 했다.
이 논문에서 Semantic-aware visual tokenizer를 학습 시킬 것이다.
Vector-quantized Knowledge Distillation(VQ-KD)
알고리즘으로 semantic space를 표현한다.
1. VQ-KD encoder가 이미지를 learnable codebook에 따라 discrete token으로 변환 한다.
2. Decoder는 이 discrete token을 가지고 teacher model이 encode한 semantic feature를 복원하는걸 학습한다.
3. VQ-KD가 학습이 끝나면, encoder는 BEIT v1 구조에서 semantic visual tokenizer로 쓴다.
token들이 discrete한 상태이기 때문에 [CLS] token이 global한 image representation을 학습하도록 patch aggregation strategy
를 적용한다.
리마인드 차원에서
BEIT v1의 image patch - visual token 구조
v2에서는 visual token을 만드는데 vector-quantized knowledge distillation
알고리즘을 활용할 예정
patch aggregation strategy
곁들인다.
원본 이미지의 representation을 얻기 위해 Backbone network로 ViT를 쓴다.
원본 이미지를 쪼개 얻은 image patch를 ,
ViT의 output; encoding vector를 로 쓴다.
N은 patch 총 갯수에 해당한다.
VQ-KD를 활용할 예정이고, visual tokenizer
, decoder
로 이루어진 구조이다.
Visual tokenizer
가 이미지를 visual tokens
; discrete codes 으로 매핑한다.
즉, 이미지 를 토큰 으로 토큰화 한다.
이때, 가 codebook의 code 하나에 해당해서 ;
K discrete codebook embedding을 가진다.
이 Tokenizer는 vision transformer encoder
와 quantizer
로 이루어져 있다. Tokenizer
가 먼저 이미지를 vector(로 만들고,
quantizer
가 codebook에서 nearest neighbor을 찾아 codebook embedding
()을 매칭해 준다.
Nearest neighbor를 norm으로 찾는 과정을 수식으로 보면
Visual Tokens를 만들면, decoder에 -normalized codebook embedding
을 넣어 준다.
decoder도 multi-layer transformer구조로 이루어져 있다. decoder의 output vectors
로 teacher model(DINO, CLIP)의 semantic feature를 학습한다.
즉, teacher model의 feature vector 와 decoder의 output 사이의 cosine similarity가 커지도록 학습한다.
Quantization process는 미분이 불가능하기 때문에 encoder output을 학습시키기 애매해서, decoder input단의 gradient를 그대로 encoder의 output에 복붙한다.
전체적인 학습 식은 다음과 같다.
D: tokenizer 학습에 쓴 이미지 Data pool
sg[ ] : forward pass 일땐 identity, backward pass일땐 0으로
하나씩 보자.
문제가 있다. Codebook을 쓸 때 'codebook collapse'가 발생한다. ; code들 중 일부만 사용하는 현상.
Empirical strategy
로 해결 가능하다.
Encoder output
과 codebook 매핑 할 때 norm을 계산한다 했었다. 이때 codebook space embedding
을 32-d로 줄여서 계산하고, 실제 decoder로 넘어가기 전에 dimension을 키워서 전달하는 방식이다.
Exponential moving average
를 계산해 codebook embedding을 업데이트 한다.
BEIT v1 논문에 등장했던 masking 규칙을 따른다.
요약하자면
Pretraining Global Representation
[CLS] 토큰
이 global representation 담당이다.
[CLS] 토큰
을 학습시키기 위해
번째 layer의 output vectors()와 마지막 layer(L번째)의 output에서 CLS 토큰을 concat한다. ->
이제 이걸 Shallow Transformer decoder
에 집어넣는다. figure 3의 오른쪽 두 ViT block이다.
Shallow Transformer decoder
부분은 [CLS] 토큰
학습할 때만 사용해서, 학습 이후엔 안쓴다.
이 shallow transformer
output으로 MIM loss 계산을 한다.
최종 Loss는 마지막 layer에서 계산한 MIM loss + shallow~
에서 계산한 MIM loss 가 된다.
이유 부분 설명을 직관으로 해 두었는데 잘 이해가 안된다.
각종 hyperparameter setting
patch size, ViT 크기, layer 수, epoch 수 등
Decoder로 Deeper ViT를 사용하는게 성능이 MIM에 있어 좋지만, codebook usage와 downstream task 성능은 낮아지는 경향을 보인다.
Codebook의 dimension 수를 줄이는게 codebook utilization을 올린다.
Patch aggregation strategy (CLS 토큰 학습용 부분)
l th layer는 9, head depth를 2로 하는게 좋은 성능을 보였다.
VQ-KD target model로 (teacher model) CLIP, DINO 실험 결과 CLIP 성능이 좋았다.
Visualization of codebook
VQ-VAE 연계 연구들에서 이미지를 토큰화하는 방식들이 제안되었다.
MIM method가 여러 연구서 등장한 바 있다.
논문 볼때마다 공부해야할게 늘어난다.
Knowledge Distillation 알아보자
얼른 Transformer 자세히 파보자. CLS token, ViT 대충 아는걸 채워야겠다.
날잡아서 코드도 한번 파자.