VIT_ZSL 코드 분석 6부 Baseline model(VIT)

이준석·2022년 6월 18일
0

VIT_ZSL

목록 보기
6/9
class ViT(nn.Module):
    def __init__(self, model_name="vit_large_patch16_224_in21k", pretrained=True):
        super(ViT, self).__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained)
        # Others variants of ViT can be used as well
        '''
        1 --- 'vit_small_patch16_224'
        2 --- 'vit_base_patch16_224'
        3 --- 'vit_large_patch16_224',
        4 --- 'vit_large_patch32_224'
        5 --- 'vit_deit_base_patch16_224'
        6 --- 'deit_base_distilled_patch16_224',
        '''

        # Change the head depending of the dataset used 
        self.vit.head = nn.Identity()
    def forward(self, x):
        x = self.vit.patch_embed(x)
        cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)  
        if self.vit.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.vit.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.vit.pos_drop(x + self.vit.pos_embed)
        x = self.vit.blocks(x)
        x = self.vit.norm(x)
        
        return x[:, 0], x[:, 1:]
        

상속 : 설명1, 설명2

def __init__(self, model_name="vit_large_patch16_224_in21k", pretrained=True): 에서 pretraine=True 를 이용해 프리트레인을 해서 weight를 가져온다.

  • "vit_large_patch16_224_in21k" 체크 해보기

timm 문서, timm git
'torch.nn.Identity' :말 그대로 입력과 동일한 tensor를 출력으로 내보내주는 layer다.
사용법 (블로그)

다른 설명 - timm 문서 속
* awar, **keyward

model.() 설명

self.vit.cls_token.expand에서 expand 설명 :설명1, 설명2 ,설명3
cls_token : 설명
cat: 설명
vit 내부 코드 중 layernorm, 공식문서, transformer 예시

분석중 timm vision transforme r분석 중 from fucntion import Pratial: 설명

profile
인공지능 전문가가 될레요

0개의 댓글