VIT_ZSL Code Analysis, Part 7: Model and Optimizer Initialization

이준석 · June 20, 2022

VIT_ZSL series, part 7 of 9

Model and Optimizer Initialization

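import collections
import torch
from torch import nn, optim

use_cuda = torch.cuda.is_available()

# DATASET, ViT, attrs_mat, uniq_train_labels and uniq_trainval_labels are assumed to be defined in earlier cells
# number of attributes per class in each dataset
if DATASET == 'AWA2':
  attr_length = 85
elif DATASET == 'CUB':
  attr_length = 312
elif DATASET == 'SUN':
  attr_length = 102
else:
  # stop here instead of failing later with an undefined attr_length
  raise ValueError("Please specify the dataset, and set attr_length equal to the attribute length")

# ViT-Large, 16x16 patches, 224x224 input, ImageNet-21k checkpoint name; ViT-Large features are 1024-d
vit = ViT("vit_large_patch16_224_in21k")
# linear head mapping the 1024-d image feature to the attribute space
mlp_g = nn.Linear(1024, attr_length, bias=False)

# register both parts so they can be addressed as model.vit and model.mlp_g
model = nn.ModuleDict({
    "vit": vit,
    "mlp_g": mlp_g})

# finetune all the parameters
for param in model.parameters():
    param.requires_grad = True

# move model to GPU if CUDA is available
if use_cuda:
    model = model.cuda()

# two parameter groups: small lr for the ViT backbone, larger lr for the new head
optimizer = optim.Adam([{"params": model.vit.parameters(), "lr": 0.00001, "weight_decay": 0.0001},
                        {"params": model.mlp_g.parameters(), "lr": 0.001, "weight_decay": 0.00001}])

# halve every group's lr at milestones 10, 20, 30 and 40 (counted in lr_scheduler.step() calls)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40], gamma=0.5)
# alternative schedule: a single 10x decay at milestone 10
#lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)

# train attributes: one attribute vector per training class
train_attrbs = attrs_mat[uniq_train_labels]
train_attrbs_tensor = torch.from_numpy(train_attrbs)
# trainval attributes: one attribute vector per train+val class
trainval_attrbs = attrs_mat[uniq_trainval_labels]
trainval_attrbs_tensor = torch.from_numpy(trainval_attrbs)
if use_cuda:
    train_attrbs_tensor = train_attrbs_tensor.cuda()
    trainval_attrbs_tensor = trainval_attrbs_tensor.cuda()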
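
attrs_mat[uniq_train_labels] in the code above uses NumPy integer-array indexing: for each class index in the list it picks that class's row of the class-attribute matrix. A minimal sketch, assuming attrs_mat is a 2-D NumPy array with one row per class; the 4-class, 3-attribute matrix and the label list are made-up values, not taken from AWA2, CUB or SUN:

```python
import numpy as np
import torch

# hypothetical class-attribute matrix: 4 classes x 3 attributes
attrs_mat = np.array([[0.1, 0.0, 0.9],
                      [0.5, 0.5, 0.0],
                      [0.0, 1.0, 0.2],
                      [0.3, 0.3, 0.3]])
# hypothetical indices of the training classes
uniq_train_labels = np.array([0, 2, 3])

# integer-array indexing selects one row per listed class, in the listed order
train_attrbs = attrs_mat[uniq_train_labels]
train_attrbs_tensor = torch.from_numpy(train_attrbs)

print(train_attrbs_tensor.shape)  # torch.Size([3, 3])
print(train_attrbs_tensor.dtype)  # torch.float64
```

The result has one row per selected class and attr_length columns, and torch.from_numpy keeps the NumPy dtype (float64 here).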
                                  
Optimizer setup for model:
optimizer = torch.optim.Adam([{"params": model.vit.parameters(), "lr": 0.00001, "weight_decay": 0.0001},
                              {"params": model.mlp_g.parameters(), "lr": 0.001, "weight_decay": 0.00001}])

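model.parameters(): returns an iterator over all parameters of a module. Here model.vit.parameters() and model.mlp_g.parameters() are passed in two separate dicts, which creates two parameter groups: the ViT backbone is fine-tuned with lr=1e-5 and weight_decay=1e-4, while the newly initialized mlp_g head uses lr=1e-3 and weight_decay=1e-5.
optim.Adam: the Adam optimizer. When it receives a list of dicts, each dict is one parameter group, and the lr and weight_decay set in a dict apply only to that group; any option a group leaves out falls back to the value passed to Adam itself.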
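
A minimal sketch of the same two-group setup that runs without the ViT weights. A small hypothetical nn.Linear(16, 1024) stands in for the backbone, because only the parameter grouping is being illustrated; attr_length = 85 is the AWA2 value from the code above.

```python
import torch
from torch import nn, optim

# hypothetical stand-in for the ViT backbone: any module producing 1024-d features
backbone = nn.Linear(16, 1024)
attr_length = 85  # AWA2
head = nn.Linear(1024, attr_length, bias=False)

model = nn.ModuleDict({"vit": backbone, "mlp_g": head})

# two parameter groups, each with its own lr and weight_decay
optimizer = optim.Adam([
    {"params": model.vit.parameters(), "lr": 1e-5, "weight_decay": 1e-4},
    {"params": model.mlp_g.parameters(), "lr": 1e-3, "weight_decay": 1e-5},
])

# inspect what the optimizer stored for each group
for i, group in enumerate(optimizer.param_groups):
    n_params = sum(p.numel() for p in group["params"])
    print(i, group["lr"], group["weight_decay"], n_params)
```

Because nn.ModuleDict registers its entries as submodules, model.vit and model["vit"] refer to the same module, which is why the attribute-style access in the original code works.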

lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40], gamma=0.5)

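lr_scheduler.MultiStepLR: multiplies the learning rate of every parameter group by gamma each time the number of lr_scheduler.step() calls reaches one of the milestones. With milestones=[10, 20, 30, 40] and gamma=0.5, and assuming step() is called once per epoch, the ViT learning rate goes 1e-5 → 5e-6 → 2.5e-6 → 1.25e-6 → 6.25e-7 and the mlp_g learning rate goes 1e-3 → 5e-4 → 2.5e-4 → 1.25e-4 → 6.25e-5.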
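
To check that schedule, a minimal sketch that steps MultiStepLR 45 times on two dummy parameter groups with the same starting learning rates, recording the rate in effect during each epoch. The once-per-epoch step() is an assumption about the training loop, which is not shown in this part.

```python
import torch
from torch import nn, optim

# hypothetical one-element parameters; only the learning rates matter here
p_vit = nn.Parameter(torch.zeros(1))
p_head = nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([{"params": [p_vit], "lr": 1e-5},
                        {"params": [p_head], "lr": 1e-3}])
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40], gamma=0.5)

lrs = []  # (ViT lr, mlp_g lr) in effect during each epoch
for epoch in range(45):
    lrs.append(tuple(g["lr"] for g in optimizer.param_groups))
    optimizer.step()     # the real training steps would go here
    lr_scheduler.step()  # assumed: called once per epoch

for epoch in (0, 9, 10, 20, 30, 40, 44):
    print(epoch, lrs[epoch])
```

Epochs 0 to 9 run at the initial rates, epoch 10 is the first one at half the rate, and from epoch 40 onward both groups stay at 1/16 of their starting values.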

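torch.from_numpy(): converts a NumPy array into a tensor without copying. The tensor shares memory with the array, so changing one also changes the other. Calling .cuda() on the result, as the code above does, copies the data to the GPU, and that GPU tensor no longer shares memory with the NumPy array.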
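
A minimal sketch of the memory sharing described above, using a small made-up float32 array; .clone() is shown as one way to get an independent copy:

```python
import numpy as np
import torch

arr = np.zeros(3, dtype=np.float32)
t = torch.from_numpy(arr)   # shares memory with arr, no copy

arr[0] = 1.0                # change the NumPy array ...
print(t)                    # ... and the tensor sees it: tensor([1., 0., 0.])
t[1] = 2.0                  # change the tensor ...
print(arr)                  # ... and the array sees it: [1. 2. 0.]

t_copy = torch.from_numpy(arr).clone()  # independent copy, no shared memory
arr[2] = 3.0
print(t_copy)               # unchanged: tensor([1., 2., 0.])
```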
