Contrastive Learning

~.~·2022년 12월 2일
0

ML

목록 보기
3/5

feature selection 을 하기 위해 constrastive learning 을 적용해봤다.
anchor, positive, negative가 들어간다.

이때 positive는 anchor와 같은 라벨, negative는 anchor와 다른 라벨이며,
이 라벨이다른 샘플에 대해 거리(여기서는 cos유사도)를 최대화 하는 방향으로 encoder를 훈련시킨다.

로스는 다음과 같다.

anchor와 positive와의 유사도는 높게, negative와의 유사도는 낮게끔 학습이 됨을 확인할 수 있다.

아래의 코드에서
anchor, pos, neg가 한꺼번에 encoder를 거쳐서 loss에 들어간다.
첫번째에는 anchor의 feature vec, 두번째에는 positive sample의 feature vec, 그 후부터는 negative sample의 feature vec이 위치해 있다.

class Contrastive_learning_loss(nn.Module):
    """
    input : (Batch, n, encoder_outdim) 

    return loss
    
    """
    
    def __init__(self, sample_num):
        super().__init__()
        self.sample_num= sample_num   # anchot 제외 , positive sample + negative sample

    def forward(self, encoder_out):
        batch_size = encoder_out.size(0)
        cos = nn.CosineSimilarity(dim =-1 , eps=1e-6)
        
        anchor = encoder_out[:, 0]  
        anchor = anchor.unsqueeze(1).repeat(1, self.sample_num, 1)       ## (Batch, sample_num, out_dim)  -> cos si, 구하기 위해 repeat.....
        
        sample_out = encoder_out[:, 1:]                                  ## (Batch, sample_num, out_dim)  

        cos_sim = cos(anchor, sample_out)                                ## (Batch, sample_num)
        log_loss = - torch.nn.functional.log_softmax(cos_sim, dim = 1)[:,0]  ## (Batch, sample_num) -> (Batch, 1 (sample_num[0]))
        log_loss = torch.sum(log_loss) / batch_size
        
        
        #f_log /= self.sample_num
        return log_loss

0개의 댓글