class ViT(nn.Module):
def __init__(self, model_name="vit_large_patch16_224_in21k", pretrained=True):
super(ViT, self).__init__()
self.vit = timm.create_model(model_name, pretrained=pretrained)
# Others variants of ViT can be used as well
'''
1 --- 'vit_small_patch16_224'
2 --- 'vit_base_patch16_224'
3 --- 'vit_large_patch16_224',
4 --- 'vit_large_patch32_224'
5 --- 'vit_deit_base_patch16_224'
6 --- 'deit_base_distilled_patch16_224',
'''
# Change the head depending of the dataset used
self.vit.head = nn.Identity()
def forward(self, x):
x = self.vit.patch_embed(x)
cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
if self.vit.dist_token is None:
x = torch.cat((cls_token, x), dim=1)
else:
x = torch.cat((cls_token, self.vit.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.vit.pos_drop(x + self.vit.pos_embed)
x = self.vit.blocks(x)
x = self.vit.norm(x)
return x[:, 0], x[:, 1:]
def __init__(self, model_name="vit_large_patch16_224_in21k", pretrained=True):
에서 pretraine=True
를 이용해 프리트레인을 해서 weight를 가져온다.
timm 문서, timm git
'torch.nn.Identity' :말 그대로 입력과 동일한 tensor를 출력으로 내보내주는 layer다.
사용법 (블로그)
다른 설명 - timm 문서 속
* awar, **keyward
self.vit.cls_token.expand에서
expand 설명
:설명1, 설명2 ,설명3
cls_token
: 설명
cat
: 설명
vit 내부 코드 중layernorm
, 공식문서, transformer 예시분석중 timm vision transforme r분석 중
from fucntion import Pratial
: 설명