[torch] pairwise cos sim 계산

~.~·2022년 12월 6일
0

ML

목록 보기
4/5
post-thumbnail

cos similarity

cos similarity는 벡터간 유사도를 구하기 위한 방법으로 식은 다음과 같다.

분자에 두 벡터의 내적, 분모의 벡터의 norm이 들어가 있다.

연산시 hidden dim = 64인 10000 개의 벡터,
즉 x = (1, 10000, 64)의 벡터에서 각각에서의 pair wise cos를 구하고자한다.

이 때 torch의 matamul과 norm 함수를 이용해 구할 수 있는데 두 함수는 다음과 같이 쓰면 된다.

c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float)
print(c @ c.T) 
# tensor([[14., 13.],
#        [13., 18.]])

print(torch.norm(c, dim=-1))
# tensor([3.7417, 4.2426])

## Pairwise cossim 구하기
output = torch.randn(1, 10000, 64)
distance = torch.norm(output, dim = -1)
# 분자
torch.bmm(output, output.transpose(1,2))
# 분모
((distance.T) @ distance)
# 최종
cos_sim = torch.bmm(output, output.transpose(1,2)) /((distance.T) @ distance)

0개의 댓글