ArcFace: Additive Angular Margin Loss for Deep Face Recognition

치키차카호우·2023년 6월 15일

들어가기 전

위 논문을 읽게 된 계기는 회사에서 Recognition 관련해서 새롭게 업무를 진행해야했기 때문이다.
정확히 Face recognition은 아니고, Face recognition task 처럼 학습에 사용한 클래스(얼굴 인식에서는 특정 사람)는 Gallery에 등록하고, 학습하지 않은 클래스(얼굴 인식에서는 새롭게 등장한 사람)는 등록되지 않았는지 확인하기 위해서 Face recognition 개념을 도입했다.(정확히는 Face identification)

Introduction

  • 해당 논문에서는 Additive Angular Margin 이라는 새로운 Loss 함수를 소개하고 있다.

  • Additive Angular Margin Loss는 Face recognition 모델의 discriminative power를 향상시킨다.

  • Backbone(ConvNet)의 output인 feature vector와 마지막 FC Layer의 weight 사이의 내적 연산은 feature vector와 각 class를 대표하는 center vector 간의 cosine distance를 계산하는 것으로 보았음(이 개념을 아는 것이 논문 및 코드를 이해하는데 아주 중요함!)

    • 각 Class에 해당하는 feature vector가 각 Class의 center vector와 cosine distance가 가깝도록 학습이 되는 것
    • 내적하기 전에 feature vector와 fc layer의 weight(center vector) 에 대해서 normalization을 수행해주는데, 이를 통해 feature vector가 구성하는 manifold가 hypersphere(원/구) 형태로 나오게 되는데, 이 때문에 ArcFace 라고 부르기도 함
  • Training Data 구성이 Clean 한 것도 있지만, labeling이 잘못되어 있거나 이미지가 noisy 한 경우도 있는데, 이를 Sub-Center ArcFace 라는 개념을 도입해 모델의 robustness 강화시킴
    • feature vector가 하나의 positive center에 가깝도록 하는 것이 아니라, k개의 sub-center에 가깝게 학습되도록 함
      • One positive center에 가깝도록 학습하는 것 보다는 constraint를 완화시켜주는 역할을 함

Proposed Approach

ArcFace

  • Classification model을 학습할 때 주로 사용하는 Softmax loss는 아래와 같음
  • 여기서 xix_iii-th sample에 대한 feature vector(yiy_i-th class)

  • WyiW_{y_i}yiy_i에 해당하는 weight vector로서, 위에 언급한 각 class에 해당하는 center vector를 의미함

  • WyiT{W_{y_i}}^Txix_i는 logit이라고 함

  • 위와 같은 loss를 했을 때 전반적으로 좋은 성능을 보이지만, intra-class(동일한 class) 샘플에 대해서는 높은 유사도, inter-class(서로 다른 class) 샘플에 대해서는 분리되도록 명시적(explicitly)으로 학습하는 구조가 아니라서 분류 성능에 한계가 있음

  • 위 식에서 bjb_jbyib_{y_i}를 0으로 두면 logit은 아래와 같이 쓸 수 있음

    • WjTxi=Wj xi cosθj{W_j}^Tx_i=||W_j|| \ ||x_i|| \ cos{\theta}_j
      • cosθjcos{\theta}_j는 weight Wj{W_j}와 feature xix_i 간의 angle을 의미함
  • 각각의 WjW_jl2l_2 normalization 을 통해 Wj=1||W_j||=1 로 만듦

  • xi||x_i|| 또한 l2l_2 normalization 하고, ss로 re-scale 하면 아래와 같은 새로운 식이 됨

  • feature vector는 주로 center feature vector 주위에 분포되어 있음

  • discriminative power를 향상 시키기 위해서 intra-class에 대해서는 compactness, inter-class에 대해서는 discrepancy를 높여야 함

  • 이를 위해 additive angular margin penalty mm을 부여함

  • 최종적으로는 아래와 같은 Additive angular margin loss가 도출됨

  • θ\theta에 따른 decision boundary를 그려보게 되면 아래와 같이 표현할 수 있다.

    • Softmax: margin이 없으므로 Class1, 2에 대한 decision boundary가 딱 붙어있다.

    • 나머지 margin-based loss: margin가 있기 때문에 Class1, 2에 대한 decision boundary 가 다소 떨어져있음을 확인할 수 있다.

      • SphereFace는 각도가 작을수록(0에 가까울수록) decision boundary가 작아지는 단점이 있다
      • CosFace는 decision boundary가 linear 하지 않음
      • ArcFace는 각도 구분 상관없이 linear 한 decision boundary를 가진다.

Code

논문은 그렇게 어려운 내용이 아니다. 동일한 class 끼리는 각도가 최소화 되도록, 다른 class끼리는 각도가 최대화 되도록 학습하는 것이고, margin을 주어 intra-class의 compactness, inter-class의 discrepancy를 강화하겠다는 것이 이 논문 내용의 전부다. 항상 그렇듯, 코드는 논문의 내용과 100% 동일하지 않다. 코드는 arcface-pytorch(https://github.com/ronghuaiyang/arcface-pytorch)를 참고한다. 전체 코드는 아래와 같다.

class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin

            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s
        # print(output)

        return output

step-by-step으로 이해해보자.

forward

        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
  • cosine: input(feature vector)와 각 class에 대한 center vector의 모음인 weight를 내적한 결과이다.
    • 만약, class 개수가 75개라고 하면, 내적 결과는 1x75가 되는데, 이는 각 class weight와 feature vector 사이의 cosine distance를 의미한다.
  • phi: 단순하게 cos(θ+m)cos(\theta +m)이라고 생각하면 된다.
    • 코사인 덧셈법칙: cos(x+y)=cos(x)sin(y)+sin(x)cos(y)cos(x+y)=cos(x)*sin(y)+sin(x)*cos(y)

easy margin

	if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

주로 easy margin을 사용하지 않으므로 else인 부분에 대해서 설명하겠다. case는 아래와 같다.

  • cos(θ)cos(\theta) > cos(πm)cos(\pi-m)일 때
    • cos(θ+m)cos(\theta+m)
  • cos(θ)cos(\theta) <= cos(πm)cos(\pi-m)일 떄,
    • cos(θ)cos(\theta) - msin(πm)m*sin(\pi-m)

이렇게 수식이 복잡한 이유는, θ+m\theta+m 값이 [0,π\pi]에 있지 않아 coscos가 단조감소 함수가 아닐 때가 있는데, 범위를 벗어나더라도 단조감소 함수로 만들기 위함이다.

  • 같은 class 끼리는 θ\theta가 작아져야함(WjW_jxix_i가 같은 class 라면 θ\theta는 작아져야함): θ\theta가 0에 수렴하면 cos(θ)cos(\theta)는 커짐
  • 다른 class 끼리는 θ\theta가 커져야함(WjW_jxix_i가 다른 class 라면 θ\theta는 작아져야함): θ\thetaπ\pi에 가까워지면 cos(θ)cos(\theta)는 작아짐

즉, 1번 case는 θ+m\theta+m 값이 [0,π][0, \pi] 일 때 즉, 단조감소함수 조건을 만족할 때는 cos(θ+m)cos(\theta+m)을 그대로 사용하겠다는 말이다.(θ+m\theta+m <= π\pi <-> θ\theta <= πm\pi-m, 즉, cos(θ)cos(\theta) <= cos(πm)cos(\pi-m))

2번 case는 θ+m\theta+m 값이 [0,π[0, \pi] 가 아닐 때를 의미하는 것이다. 이 때, taylor series로 근사하여 단조감소함수를 만들겠다는 의미이다.

사실, θ+m\theta+m 값이 [0,π[0, \pi]를 벗어났을 때 제일 쉬운 방법은 cos(θ+m)=1cos(\theta+m)=-1로 표현하는 방법이지만, 정확도가 떨어진다.

다른 방법은, 위에서 얘기한 것처럼 taylor series를 이용하여 근사하는 방법이 있다.(위 코드에서는 taylor series에서 1차 선형식까지만 이용하여 근사한다. 물론, 더 근사하면 정확도는 높아지겠지만, 크게 차이가 없는데 computational cost만 높이니까 1차만 사용한 게 아닐까?)

  • 위 근사식에 따라, cos(θ+m)cos(\theta+m) \sim cos(θ)cos(\theta) msin(θ)-m*sin(\theta)
  • θ+m\theta+m >= π\pi 즉, θ\theta >= πm\pi-m 인 경우에는 sin(θ)>=sin(πm)sin(\theta)>=sin(\pi-m) 이므로, 최종적으로 아래식으로 근사된다.
  • cos(θ+m)cos(\theta+m) \sim cos(θ)cos(\theta) msin(πm)-msin(\pi-m)

logit

        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s
        # print(output)
  • 이 부분은, target class만 cos(θ+m)cos(\theta+m)만 하기 위함이며, target이 아닌 나머지 class에 대해서는 cos(θ)cos(\theta)로 사용한다.
  • 그리고 scale 값인 s를 곱해줘서 최종적으로 output을 만들며, 이후 cross-entropy loss를 통해 학습하게 된다.
profile
인생 1회차 - 이 정도면 잘하고이따

0개의 댓글