[AI] ViT formulation

JAsmine_log·2024년 8월 12일
0

ViT

Formulation

Transformer block과 MSA(Multihead Self-Attention)의 수식을 과 의미를 살펴보자

Transformer Encoder

  • LN : Layer Norm
  • MLP : Multi Layer Perceptron
  • MSA : Multihead Self Attention

(1) [Patch Embbedings]
z0=[xclass;xp1E;xp2;;xpNE]+Eposz_0=[x_{class}; x_p^1E;x_p^2;\cdot\cdot\cdot;x_p^NE]+E_{pos}, ER(p2C)×DE \in {\mathbb{R}^{(p^2\cdot C) \times D}}, EposRN=1)×DE_{pos} \in \mathbb{R}^{N=1)\times D}

(2) [MLP Block]
zl=MSA(LN(zl1))+zl1z^{\prime}_l = MSA(LN(z_{l-1}))+z_{l-1}, l=1Ll=1\cdot\cdot\cdot L

(3) [MLP Block]
zl=MLP(LN(zl))+zlz_l = MLP(LN(z^{\prime}_l))+z^{\prime}_l , l=1Ll=1\cdot\cdot\cdot L

(4) [Position Embeddings]
y=LN(zL0)y=LN(z^0_L)

Multihead Self Attention(instead of transformer encoder)

(5) [q,k,v]=zUzkv,UqkvRD×3Dh[q, k, v]=zU_{zkv}, U{qkv} \in {\mathbb{R}^{D\times} 3D_h}

(6) A=softmax(qkDh),ARN×MA = softmax(qk ^\intercal \sqrt{D_h}), A\in\mathbb{R}^{N \times M}

(7) SA(z)=AvSA(z)=Av

(8) MSA(z)=[SA1(z);SA2(z);;SAk(z)]Umsa,UmsaRkDh×DMSA(z)=[SA_1(z);SA_2(z);\cdots;SA_k(z)]U_{msa}, U_{msa}\in \mathbb{R}^{k \cdot D_h \times D}

위의 수식들은 MSA으로 입력 시퀀스를 변환하여 입력벡터 zzqkvqkv를 생성하고 각 Head에서 self-attention을 적용한 후, 그 결과를 연결하고 다시 변환하여 최종 출력을 생성하는 과정이다.


Explain

(5)입력 벡터 zz로부터 쿼리(queries), 키(keys), 값(values) 벡터를 생성하는 과정

  • zz : 입력 시퀀스의 표현으로, 길이NN의 시퀀스에서 각 위치에서의 벡터ziz_i의 차원은DD
  • UqkvU_{qkv} : zz를 쿼리qq, 키kk, 값vv로 변환하기 위한 가중치 행렬로 크기는D×3DhD \times 3D_h이며, DhD_h는 각 헤드(head)에서의 차원임
  • qq,kk,vv는 모두RN×Dh\mathbb{R}^{N \times D_h}의 차원을 가짐

(6) 어텐션 가중치AA를 계산하는 과정

  • qkqk^\intercal : 쿼리와 키의 내적으로,qqkk의 내적을 계산하면 RN×M\mathbb{R}^{N \times M} 크기의 행렬이 되고, NNMM은 입력 시퀀스의 길이를 나타냄

(7) 단일 어텐션 헤드에서의 셀프 어텐션(SA)을 계산하는 방법을 설명합니다.

  • AA : 앞에서 계산된 어텐션 가중치 행렬
  • vv : 값 벡터
  • AvAv : 어텐션 가중치AA를 값vv에 적용하여 셀프 어텐션 출력을 계산하고, 출력은 RN×Dh\mathbb{R}^{N \times D_h}의 차원

(8) Multihead Self Attention(MSA)

  • SAi(z)SA_i(z) : 각 어텐션 헤드ii의 셀프 어텐션 출력이고, kk개의 헤드를 가지므로kk개의SAi(z)SA_i(z)가 있음
  • [SA1(z);SA2(z);;SAk(z)][SA_1(z); SA_2(z); \cdots; SA_k(z)] : kk개의 어텐션 헤드 출력을 하나로 연결(concatenate)한 것으로, 결과물은 RN×kDh\mathbb{R}^{N \times k \cdot D_h} 차원임
  • UmsaU_{msa} : 연결된 출력을 최종 출력 MSA(z)MSA(z)로 변환하기 위한 가중치 행렬이고, 크기는 kDh×Dk \cdot D_h \times D입니다.
  • MSA(z)MSA(z)RN×D\mathbb{R}^{N \times D}의 차원을 가지는 출력 벡터

Reference
[1] https://github.com/lucidrains/vit-pytorch
[2] https://github.com/huggingface/pytorch-image-models
[3] https://github.com/jankrepl/mildlyoverfitted/tree/master/github_adventures/vision_transformer

profile
Everyday Research & Development

0개의 댓글