HuggingFace 및 공식 코드의 Loss는 아래와 같다.
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0
class CLIP(...):
def __init__():
...
self.visual_projection = nn.Linear(self.visual_hidden_size, self.projection_dim, bias=False)
self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim, bias=False)
self.logit_scale = nn.Parameter(torch.tensor(self.logit_scale_init_value))
...
def forward(..):
# image & text embeds == (Batch_size, Linear Hidden size)
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits -> (Batch_size, Batch_size)
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()
loss = None
if return_loss:
loss = clip_loss(logits_per_text)
위 코드에서 생기는 의문은 다음과 같다.
cosine similarity
를 구한다고 하는데, 각 임베딩 벡터들의 내적을 계산한다. (각 임베딩 벡터의 크기로 나누어주지 않음)contrastive_loss
에서 label이 왜 torch.arange(len(logits)
로 구현되는지?torch.arange(len(logits))
는 positive pair의 position을 나타낸다.
예를 들어, torch.arange(5)
라면 [0, 1, 2, 3, 4]
가 될 것이고 이는 one-hot labels로 바뀌게 되어 각각, 0행의 0번째, 1행의 1번째... 를 나타내므로 즉, 행렬의 주대각선의 위치를 나타내준다.
정리하면, cross entropy 계산을 할 때 positive pair에 대해 label이 1이고 나머지는 0으로 계산될 것
text loss의 경우 아래 사진의 주황색 부분처럼, 각 행에 대해 cross entropy를 구하는 것이고, image loss의 경우 파란색 부분처럼, 각 열에 대해 cross entropy를 구하는 것이다.