위 논문을 읽게 된 계기는 회사에서 Recognition 관련해서 새롭게 업무를 진행해야했기 때문이다.
정확히 Face recognition은 아니고, Face recognition task 처럼 학습에 사용한 클래스(얼굴 인식에서는 특정 사람)는 Gallery에 등록하고, 학습하지 않은 클래스(얼굴 인식에서는 새롭게 등장한 사람)는 등록되지 않았는지 확인하기 위해서 Face recognition 개념을 도입했다.(정확히는 Face identification)
해당 논문에서는 Additive Angular Margin 이라는 새로운 Loss 함수를 소개하고 있다.
Additive Angular Margin Loss는 Face recognition 모델의 discriminative power를 향상시킨다.
Backbone(ConvNet)의 output인 feature vector와 마지막 FC Layer의 weight 사이의 내적 연산은 feature vector와 각 class를 대표하는 center vector 간의 cosine distance를 계산하는 것으로 보았음(이 개념을 아는 것이 논문 및 코드를 이해하는데 아주 중요함!)


여기서 는 -th sample에 대한 feature vector(-th class)
는 에 해당하는 weight vector로서, 위에 언급한 각 class에 해당하는 center vector를 의미함
는 logit이라고 함
위와 같은 loss를 했을 때 전반적으로 좋은 성능을 보이지만, intra-class(동일한 class) 샘플에 대해서는 높은 유사도, inter-class(서로 다른 class) 샘플에 대해서는 분리되도록 명시적(explicitly)으로 학습하는 구조가 아니라서 분류 성능에 한계가 있음
위 식에서 및 를 0으로 두면 logit은 아래와 같이 쓸 수 있음
각각의 를 normalization 을 통해 로 만듦
또한 normalization 하고, 로 re-scale 하면 아래와 같은 새로운 식이 됨

feature vector는 주로 center feature vector 주위에 분포되어 있음
discriminative power를 향상 시키기 위해서 intra-class에 대해서는 compactness, inter-class에 대해서는 discrepancy를 높여야 함
이를 위해 additive angular margin penalty 을 부여함
최종적으로는 아래와 같은 Additive angular margin loss가 도출됨

에 따른 decision boundary를 그려보게 되면 아래와 같이 표현할 수 있다.

Softmax: margin이 없으므로 Class1, 2에 대한 decision boundary가 딱 붙어있다.
나머지 margin-based loss: margin가 있기 때문에 Class1, 2에 대한 decision boundary 가 다소 떨어져있음을 확인할 수 있다.
논문은 그렇게 어려운 내용이 아니다. 동일한 class 끼리는 각도가 최소화 되도록, 다른 class끼리는 각도가 최대화 되도록 학습하는 것이고, margin을 주어 intra-class의 compactness, inter-class의 discrepancy를 강화하겠다는 것이 이 논문 내용의 전부다. 항상 그렇듯, 코드는 논문의 내용과 100% 동일하지 않다. 코드는 arcface-pytorch(https://github.com/ronghuaiyang/arcface-pytorch)를 참고한다. 전체 코드는 아래와 같다.
class ArcMarginProduct(nn.Module):
r"""Implement of large margin arc distance: :
Args:
in_features: size of each input sample
out_features: size of each output sample
s: norm of input feature
m: margin
cos(theta + m)
"""
def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
super(ArcMarginProduct, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.s = s
self.m = m
self.weight = Parameter(torch.FloatTensor(out_features, in_features))
nn.init.xavier_uniform_(self.weight)
self.easy_margin = easy_margin
self.cos_m = math.cos(m)
self.sin_m = math.sin(m)
self.th = math.cos(math.pi - m)
self.mm = math.sin(math.pi - m) * m
def forward(self, input, label):
# --------------------------- cos(theta) & phi(theta) ---------------------------
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
phi = cosine * self.cos_m - sine * self.sin_m
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine)
else:
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
# --------------------------- convert label to one-hot ---------------------------
# one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
one_hot = torch.zeros(cosine.size(), device='cuda')
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
# -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
output *= self.s
# print(output)
return output
step-by-step으로 이해해보자.
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
phi = cosine * self.cos_m - sine * self.sin_m
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine)
else:
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine)
else:
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
주로 easy margin을 사용하지 않으므로 else인 부분에 대해서 설명하겠다. case는 아래와 같다.
이렇게 수식이 복잡한 이유는, 값이 [0,]에 있지 않아 가 단조감소 함수가 아닐 때가 있는데, 범위를 벗어나더라도 단조감소 함수로 만들기 위함이다.
즉, 1번 case는 값이 일 때 즉, 단조감소함수 조건을 만족할 때는 을 그대로 사용하겠다는 말이다.( <= <-> <= , 즉, <= )
2번 case는 값이 ] 가 아닐 때를 의미하는 것이다. 이 때, taylor series로 근사하여 단조감소함수를 만들겠다는 의미이다.
사실, 값이 ]를 벗어났을 때 제일 쉬운 방법은 로 표현하는 방법이지만, 정확도가 떨어진다.
다른 방법은, 위에서 얘기한 것처럼 taylor series를 이용하여 근사하는 방법이 있다.(위 코드에서는 taylor series에서 1차 선형식까지만 이용하여 근사한다. 물론, 더 근사하면 정확도는 높아지겠지만, 크게 차이가 없는데 computational cost만 높이니까 1차만 사용한 게 아닐까?)

# one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
one_hot = torch.zeros(cosine.size(), device='cuda')
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
# -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
output *= self.s
# print(output)