def validation(model, seen_loader, seen_labels, unseen_loader, unseen_labels, attrs_mat, use_cuda, gamma=None):
    """Report ZSL/GZSL validation accuracy and select the calibration gamma.

    Representations for the seen and unseen loaders are obtained from
    ``get_reprs`` and scored against the rows of ``attrs_mat`` that belong to
    the classes present in the validation labels.  For GZSL, a grid of gamma
    values is tried (calibrated stacking: gamma is subtracted from the softmax
    scores of the seen classes) and the gamma with the highest harmonic mean of
    seen/unseen accuracy is kept.

    Args:
        model: Model handed to ``get_reprs``.
        seen_loader: Loader for the seen-class validation samples.
        seen_labels: Integer class labels of the seen samples; used as row
            indices into ``attrs_mat``.
        unseen_loader: Loader for the unseen-class validation samples.
        unseen_labels: Integer class labels of the unseen samples; used as row
            indices into ``attrs_mat``.
        attrs_mat: Class-attribute matrix, one row per class.
        use_cuda: Forwarded to ``get_reprs``.
        gamma: Ignored.  Kept only for interface compatibility; the value is
            always determined by the grid search below.

    Returns:
        The gamma from the grid that gave the highest harmonic mean, or ``0``
        if no candidate produced a harmonic mean greater than zero.
    """
    # Representation
    with torch.no_grad():
        seen_reprs = get_reprs(model, seen_loader, use_cuda)
        unseen_reprs = get_reprs(model, unseen_loader, use_cuda)

    # Labels: remap the original class ids to positions in the sorted array of
    # classes that actually occur, so they index the truncated attribute matrix.
    uniq_labels = np.unique(np.concatenate([seen_labels, unseen_labels]))
    updated_seen_labels = np.searchsorted(uniq_labels, seen_labels)
    uniq_updated_seen_labels = np.unique(updated_seen_labels)
    updated_unseen_labels = np.searchsorted(uniq_labels, unseen_labels)
    uniq_updated_unseen_labels = np.unique(updated_unseen_labels)

    # Truncate the attribute matrix to the classes present in validation.
    trunc_attrs_mat = attrs_mat[uniq_labels]

    #### ZSL ####
    # Only unseen classes are candidates; argmax indexes into the unseen subset,
    # so map it back to the remapped label space before scoring.
    zsl_unseen_sim = unseen_reprs @ trunc_attrs_mat[uniq_updated_unseen_labels].T
    pred_labels = np.argmax(zsl_unseen_sim, axis=1)
    zsl_unseen_predict_labels = uniq_updated_unseen_labels[pred_labels]
    zsl_unseen_acc = compute_accuracy(zsl_unseen_predict_labels, updated_unseen_labels, uniq_updated_unseen_labels)

    #### GZSL ####
    # All classes are candidates, so argmax is already in the remapped label space.
    gzsl_seen_sim = softmax(seen_reprs @ trunc_attrs_mat.T, axis=1)
    gzsl_unseen_sim = softmax(unseen_reprs @ trunc_attrs_mat.T, axis=1)

    gamma_opt = 0
    H_max = 0
    gzsl_seen_acc_max = 0
    gzsl_unseen_acc_max = 0

    # Calibrated stacking: entries for unseen classes stay 0 for every
    # candidate, so the penalty vector is allocated once and only the
    # seen-class entries are rewritten.
    gamma_mat = np.zeros(trunc_attrs_mat.shape[0])
    for candidate in np.arange(0.0, 1.1, 0.1):
        gamma_mat[uniq_updated_seen_labels] = candidate

        gzsl_seen_pred_labels = np.argmax(gzsl_seen_sim - gamma_mat, axis=1)
        gzsl_seen_acc = compute_accuracy(gzsl_seen_pred_labels, updated_seen_labels, uniq_updated_seen_labels)

        gzsl_unseen_pred_labels = np.argmax(gzsl_unseen_sim - gamma_mat, axis=1)
        gzsl_unseen_acc = compute_accuracy(gzsl_unseen_pred_labels, updated_unseen_labels, uniq_updated_unseen_labels)

        # Harmonic mean; defined as 0 when both accuracies are 0 to avoid a
        # division by zero.
        acc_sum = gzsl_seen_acc + gzsl_unseen_acc
        H = 2 * gzsl_seen_acc * gzsl_unseen_acc / acc_sum if acc_sum > 0 else 0.0

        if H > H_max:
            gzsl_seen_acc_max = gzsl_seen_acc
            gzsl_unseen_acc_max = gzsl_unseen_acc
            H_max = H
            gamma_opt = candidate

    print('ZSL: averaged per-class accuracy: {0:.2f}'.format(zsl_unseen_acc * 100))
    print('GZSL Seen: averaged per-class accuracy: {0:.2f}'.format(gzsl_seen_acc_max * 100))
    print('GZSL Unseen: averaged per-class accuracy: {0:.2f}'.format(gzsl_unseen_acc_max * 100))
    print('GZSL: harmonic mean (H): {0:.2f}'.format(H_max * 100))
    print('GZSL: gamma: {0:.2f}'.format(gamma_opt))
    return gamma_opt
def validation(model, seen_loader, seen_labels, unseen_loader, unseen_labels, attrs_mat, use_cuda, gamma=None):
# Representation
with torch.no_grad():
seen_reprs = get_reprs(model, seen_loader, use_cuda)
unseen_reprs = get_reprs(model, unseen_loader, use_cuda)
with
: 열고 닫는 작업에서 닫기가 자동으로 처리되어야 할 때 사용
get_reprs
: 정리 완료
# Labels
uniq_labels = np.unique(np.concatenate([seen_labels, unseen_labels]))
np.concatenate
로 seen label 과 unseen label 을 하나의 배열로 합쳐줌
updated_seen_labels = np.searchsorted(uniq_labels, seen_labels)
uniq_updated_seen_labels = np.unique(updated_seen_labels)
np.searchsorted
: 정렬된 배열(uniq_labels)에서 내가 찾고자 하는 값의 위치(index)를 알려줌
updated_unseen_labels = np.searchsorted(uniq_labels, unseen_labels)
uniq_updated_unseen_labels = np.unique(updated_unseen_labels)
uniq_updated_labels = np.unique(np.concatenate([updated_seen_labels, updated_unseen_labels]))
trunc_attrs_mat = attrs_mat[uniq_labels]
uniq_labels 에 해당하는 class 의 attribute 행만 뽑아냄
zsl_unseen_sim = unseen_reprs @ trunc_attrs_mat[uniq_updated_unseen_labels].T
pred_labels = np.argmax(zsl_unseen_sim, axis=1)
zsl_unseen_predict_labels = uniq_updated_unseen_labels[pred_labels]
zsl_unseen_acc = compute_accuracy(zsl_unseen_predict_labels, updated_unseen_labels, uniq_updated_unseen_labels)
#### GZSL ####
# seen classes
gzsl_seen_sim = softmax(seen_reprs @ trunc_attrs_mat.T, axis=1)
# unseen classes
gzsl_unseen_sim = softmax(unseen_reprs @ trunc_attrs_mat.T, axis=1)
softmax
: 각 행(axis=1)의 similarity 값을 합이 1이 되는 확률 형태로 변환
gammas = np.arange(0.0, 1.1, 0.1)
gamma_opt = 0
H_max = 0
gzsl_seen_acc_max = 0
gzsl_unseen_acc_max = 0
np.arange(0.0, 1.1, 0.1)
: 0.0 부터 1.0 까지 0.1 간격으로 gamma 후보 값을 생성 (끝값 1.1 은 포함되지 않음)
for igamma in range(gammas.shape[0]):
# Calibrated stacking
gamma = gammas[igamma]
gamma_mat = np.zeros(trunc_attrs_mat.shape[0])
gamma_mat[uniq_updated_seen_labels] = gamma
gzsl_seen_pred_labels = np.argmax(gzsl_seen_sim - gamma_mat, axis=1)
# gzsl_seen_predict_labels = uniq_updated_labels[pred_seen_labels]
gzsl_seen_acc = compute_accuracy(gzsl_seen_pred_labels, updated_seen_labels, uniq_updated_seen_labels)
gzsl_unseen_pred_labels = np.argmax(gzsl_unseen_sim - gamma_mat, axis=1)
# gzsl_unseen_predict_labels = uniq_updated_labels[pred_unseen_labels]
gzsl_unseen_acc = compute_accuracy(gzsl_unseen_pred_labels, updated_unseen_labels, uniq_updated_unseen_labels)
H = 2 * gzsl_seen_acc * gzsl_unseen_acc / (gzsl_seen_acc + gzsl_unseen_acc)
if H > H_max:
gzsl_seen_acc_max = gzsl_seen_acc
gzsl_unseen_acc_max = gzsl_unseen_acc
H_max = H
gamma_opt = gamma
gamma = gammas[igamma]
gamma_mat = np.zeros(trunc_attrs_mat.shape[0])
gamma_mat[uniq_updated_seen_labels] = gamma
igamma 는 for문에 의해 range(gammas.shape[0])
를 통해서 순서대로 들어온다.
따라서 gamma = gammas[igamma]
에서 gammas = np.arange(0.0, 1.1, 0.1)
의 값이 순서대로 하나씩 gamma
에 들어가게 됨
gamma_mat = np.zeros()
로 내부는 0으로 채워짐
하지만 gamma_mat
는 uniq_updated_seen_labels
에 의해 seen label 에 해당하는 위치에만 현재 gamma
값이 들어가게 됨 (랜덤 수가 아니라 현재 반복의 gamma 값).
gzsl_seen_pred_labels = np.argmax(gzsl_seen_sim - gamma_mat, axis=1)
# gzsl_seen_predict_labels = uniq_updated_labels[pred_seen_labels]
gzsl_seen_acc = compute_accuracy(gzsl_seen_pred_labels, updated_seen_labels, uniq_updated_seen_labels)
gzsl_seen_pred_labels = np.argmax(gzsl_seen_sim - gamma_mat, axis=1)
은 아까 위에서 새로 계산한 gamma_mat
를 빼서 np.argmax를 수행.
gzsl_seen_acc = compute_accuracy(gzsl_seen_pred_labels, updated_seen_labels, uniq_updated_seen_labels)
를 통해 정확도 다시 계산.
gzsl_unseen_pred_labels = np.argmax(gzsl_unseen_sim - gamma_mat, axis=1)
# gzsl_unseen_predict_labels = uniq_updated_labels[pred_unseen_labels]
gzsl_unseen_acc = compute_accuracy(gzsl_unseen_pred_labels, updated_unseen_labels, uniq_updated_unseen_labels)
unseen 역시 동일 하게 계산.
H = 2 * gzsl_seen_acc * gzsl_unseen_acc / (gzsl_seen_acc + gzsl_unseen_acc)
Harmonic mean (조화 평균) 계산
if H > H_max:
gzsl_seen_acc_max = gzsl_seen_acc
gzsl_unseen_acc_max = gzsl_unseen_acc
H_max = H
gamma_opt = gamma
H
값이 H_max
값을 넘을 때마다 갱신
gamma 후보 중에서 H_max
값을 갱신한 gamma 만 gamma_opt 로 저장됨
print('ZSL: averaged per-class accuracy: {0:.2f}'.format(zsl_unseen_acc * 100))
print('GZSL Seen: averaged per-class accuracy: {0:.2f}'.format(gzsl_seen_acc_max * 100))
print('GZSL Unseen: averaged per-class accuracy: {0:.2f}'.format(gzsl_unseen_acc_max * 100))
print('GZSL: harmonic mean (H): {0:.2f}'.format(H_max * 100))
print('GZSL: gamma: {0:.2f}'.format(gamma_opt))
return gamma_opt
마지막으로 validation 함수는 gamma_opt
를 반환한다.
def test(model, test_seen_loader, test_seen_labels, test_unseen_loader, test_unseen_labels, attrs_mat, use_cuda, gamma):
    """Report ZSL and GZSL test accuracy for a fixed calibration gamma.

    Representations for the seen and unseen test loaders are obtained from
    ``get_reprs`` and scored against ``attrs_mat``.  Labels are used directly
    as row indices into ``attrs_mat`` (no remapping).  For GZSL, ``gamma`` is
    subtracted from the softmax scores of the seen classes (calibrated
    stacking) before taking the argmax.  Results are printed; nothing is
    returned.

    Args:
        model: Model handed to ``get_reprs``.
        test_seen_loader: Loader for the seen-class test samples.
        test_seen_labels: Integer class labels of the seen test samples.
        test_unseen_loader: Loader for the unseen-class test samples.
        test_unseen_labels: Integer class labels of the unseen test samples.
        attrs_mat: Class-attribute matrix, one row per class.
        use_cuda: Forwarded to ``get_reprs``.
        gamma: Calibration value subtracted from seen-class scores.
    """
    # Representation
    with torch.no_grad():
        seen_reprs = get_reprs(model, test_seen_loader, use_cuda)
        unseen_reprs = get_reprs(model, test_unseen_loader, use_cuda)

    # Labels
    uniq_test_seen_labels = np.unique(test_seen_labels)
    uniq_test_unseen_labels = np.unique(test_unseen_labels)

    # ZSL: only unseen classes are candidates; argmax indexes into the unseen
    # subset, so map it back to the original class ids before scoring.
    zsl_unseen_sim = unseen_reprs @ attrs_mat[uniq_test_unseen_labels].T
    predict_labels = np.argmax(zsl_unseen_sim, axis=1)
    zsl_unseen_predict_labels = uniq_test_unseen_labels[predict_labels]
    zsl_unseen_acc = compute_accuracy(zsl_unseen_predict_labels, test_unseen_labels, uniq_test_unseen_labels)

    # Calibrated stacking: penalty of gamma on seen classes, 0 elsewhere.
    Cs_mat = np.zeros(attrs_mat.shape[0])
    Cs_mat[uniq_test_seen_labels] = gamma

    # GZSL: every row of attrs_mat is a candidate, so argmax is already a class id.
    # seen classes
    gzsl_seen_sim = softmax(seen_reprs @ attrs_mat.T, axis=1) - Cs_mat
    gzsl_seen_predict_labels = np.argmax(gzsl_seen_sim, axis=1)
    gzsl_seen_acc = compute_accuracy(gzsl_seen_predict_labels, test_seen_labels, uniq_test_seen_labels)

    # unseen classes
    gzsl_unseen_sim = softmax(unseen_reprs @ attrs_mat.T, axis=1) - Cs_mat
    gzsl_unseen_predict_labels = np.argmax(gzsl_unseen_sim, axis=1)
    gzsl_unseen_acc = compute_accuracy(gzsl_unseen_predict_labels, test_unseen_labels, uniq_test_unseen_labels)

    # Harmonic mean; defined as 0 when both accuracies are 0 to avoid a
    # division by zero.
    acc_sum = gzsl_unseen_acc + gzsl_seen_acc
    H = 2 * gzsl_unseen_acc * gzsl_seen_acc / acc_sum if acc_sum > 0 else 0.0

    print('ZSL: averaged per-class accuracy: {0:.2f}'.format(zsl_unseen_acc * 100))
    print('GZSL Seen: averaged per-class accuracy: {0:.2f}'.format(gzsl_seen_acc * 100))
    print('GZSL Unseen: averaged per-class accuracy: {0:.2f}'.format(gzsl_unseen_acc * 100))
    print('GZSL: harmonic mean (H): {0:.2f}'.format(H * 100))
    print('GZSL: gamma: {0:.2f}'.format(gamma))
validation과 동일 (단, gamma 를 탐색하지 않고 인자로 받은 gamma 를 그대로 사용)