VIT_ZSL 코드 분석 3부 Train function 및 get, compute

이준석·2022년 6월 16일
0

VIT_ZSL

목록 보기
3/9

Train Function

train

def train(model, data_loader, train_attrbs, optimizer, use_cuda, lamb_1=1.0):
    """returns trained model"""    
    # initialize variables to monitor training and validation loss
    loss_meter = AverageMeter()
    """ train the model  """
    model.train()
    tk = tqdm(data_loader, total=int(len(data_loader)))
    for batch_idx, (data, label) in enumerate(tk):
        # move to GPU
        if use_cuda:
            data,  label = data.cuda(), label.cuda()
        optimizer.zero_grad()
        
        x_g = model.vit(data)[0]
        # global feature
        feat_g = model.mlp_g(x_g)
        logit_g = feat_g @ train_attrbs.T
        loss = lamb_1 * F.cross_entropy(logit_g, label)
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item(), label.shape[0])
        tk.set_postfix({"loss": loss_meter.avg})
        
    # print training/validation statistics 
    print('Train: Average loss: {:.4f}\n'.format(loss_meter.avg))

def train(model, data_loader, train_attrbs, optimizer, use_cuda, lamb_1=1.0)

train_attrbs: 특성값 넣는 곳인데 나중에 이부분 다시 확인

  • 굉장히 햇갈리는 곳이다.

loss_meter = AverageMeter() 2부에서 한번 정리 했음

model.train() 모델 트레인 함 (4부에서 model 정리할 예정)

pytorch train() 용어 설명

tk = tqdm(data_loader, total=int(len(data_loader))) tqdm 은 시간을 측정할때 많이 씀

tk.set_postfix({"loss": loss_meter.avg})

tqdm : 블로그, 공식 문서, set_postfix
enumerate : 링크, 순서와 객체를 리턴함

mlp_g : 나중에 따로 정의함(5,6부에서)
model.vit, model.mlp_g는 model.ModuleDict 에서 정의 해줌.

  • x_g = model.vit(data)[0] 확인할것 왜 [0] 인지 확인

logit_g = feat_g @ train_attrbs.T flatten 한것과 train_attrb.T 의 합침

@두행렬간 곱 ,
예시 tensor @ tensor.T, 다른 예 tensor.matmul(tensor.T)

lamb_1은 learning_rate

loss = lamb_1 * F.cross_entropy(logit_g, label)
설명, 설명2

item() 함수는 텐서 속의 숫자를 스칼라 값으로 반환하고
loss.backward()
optimizer.step()
설명, 설명2, 설명3


get_reprs & compute_accuracy

def get_reprs(model, data_loader, use_cuda):
    model.eval()
    reprs = []
    for _, (data, _) in enumerate(data_loader):
        if use_cuda:
            data = data.cuda()
        with torch.no_grad():
            # only take the global feature
            feat = model.vit(data)[0]
            feat = model.mlp_g(feat)
        reprs.append(feat.cpu().data.numpy())
    reprs = np.concatenate(reprs, 0)
    return reprs

feat.cpu().data.numpy() 설명
model.eval()1, model.eval()2

def compute_accuracy(pred_labels, true_labels, labels):
    acc_per_class = np.zeros(labels.shape[0])
    for i in range(labels.shape[0]):
        idx = (true_labels == labels[i])
        acc_per_class[i] = np.sum(pred_labels[idx] == true_labels[idx]) / np.sum(idx)
    return np.mean(acc_per_class)
  • 직접실행 시켜 보기 compute_accuracy 구조파악
  • 여기가 잘 이해 안됨, (true_labels == labels[i]) 왜 idx로 무엇이 전달 되는지
  • 또한 np.sum(pred_labels[idx] == true_labels[idx]) / np.sum(idx) 이해가 잘...
profile
인공지능 전문가가 될레요

0개의 댓글