VIT_ZSL 코드분석 8부 Training and Testing the model

이준석·2022년 6월 20일
0

VIT_ZSL

목록 보기
8/9

Training and Testing the model

Setting the clibration factor

""" Only Run this cell if you are to tune the calibration factor (gamma)
    It is data-dependent, and decided based on the validation set """
gammas = []
for i in range(20):
    train(model, train_data_loader, train_attrbs_tensor, optimizer, use_cuda, lamb_1=1.0)
    lr_scheduler.step()
    gamma = validation(model, val_seen_data_loader, val_seen_labels, val_unseen_data_loader, val_unseen_labels, attrs_mat, use_cuda)
    gammas.append(gamma)
gamma = np.mean(gammas)
print(gamma)

논문에서 다른 참조 논문에 따르면 gamma를 이용하면 더 잘 된다고 하여 이 논문이 참고하였따.

validataion 에서 return을 gamma로 받는다.
그래서 gammas.append 로 gamma를 list로 받은후
np.mean을 통해서 gammas의 값들은 평균을 구한다.
이 코드는 gamma를 구하기 위한 calibration 의 값이다.

Calibartion factor is Set

It is 0.9 for AWA2 and CUB 0.4 for SUN

if DATASET == 'AWA2':
  gamma = 0.9
elif DATASET == 'CUB':
  gamma = 0.9
elif DATASET == 'SUN':
  gamma = 0.4
else:
  print("Please specify the dataset, and set {attr_length} equal to the attribute length")
print('Dataset:', DATASET, '\nGamma:',gamma)

gamma 값을 세팅 한다.

for i in range(80):
    train(model, trainval_data_loader, trainval_attrbs_tensor, optimizer, use_cuda, lamb_1=1.0)
    print(' .... Saving model ...')
    print('Epoch: ', i)
    save_path= str(DATASET) + '__ViT-ZSL__' +'Epoch_' + str(i) + '.pt'
    ckpt_path = './checkpoint/' + str(DATASET)
    path = os.path.join(ckpt_path, save_path)
    torch.save(model.state_dict(), path)

    lr_scheduler.step()
    test(model, test_seen_data_loader, test_seen_labels, test_unseen_data_loader, test_unseen_labels, attrs_mat, use_cuda, gamma)

폴더가 안 만들어져 있으면 os.makedir 를 통해서 만들으면 된다.

profile
인공지능 전문가가 될레요

0개의 댓글