Metric Learning - Machine Learning for Visual Understanding 8

zzwon1212 · July 24, 2024

17. Metric Learning

  • Task

    • Metric learning is the task of learning a distance function over objects.
    • A model takes two or more examples and outputs a (non-negative) score.
    • The distance may have different meaning depending on the data.
    • The model will learn the relationship in the training data, whatever it actually means.
  • Data

    • (weakly) Supervised learning
      • But it often comes free of human labor
    • Relative similarity data is easier to collect than traditionally labeled data.
      • Photos taken at the same place
      • Videos watched in the same YouTube session
      • Data augmentation
    • Ordinal regression
      • Map to an ordered set of classes

17.1. Learning to Rank

Training data consists of lists of items with some partial order specified between items in each list. The ranking model aims to rank, i.e., to produce a permutation of items in new, unseen lists in a way similar to the rankings in the training data.

  • Problem Formulation

    • Point-wise
    • Pair-wise
      • Each training example is an ordered pair of two items for a query.
      • We train the model to predict a score per item for a given query, where the order between the two items is preserved.
      • Usually, the goal is to minimize the average number of inversions in ranking.
    • List-wise
  • Representation Learning
    We optimize the model to discriminate or order items as in the training set, hoping that this ability generalizes to unseen items. By observing which item is more similar than another, the model can learn a general understanding.

  • Evaluation Metrics

    • Normalized Discounted Cumulative Gain (NDCG)
      $$\text{DCG}_p = \sum_{i=1}^p \frac{rel_i}{\log_2 (i+1)} \qquad \text{DCG}_p = \sum_{i=1}^p \frac{2^{rel_i} - 1}{\log_2 (i+1)}$$
      • $rel_i = 1$ if the $i$-th output item is actually relevant to the query, and 0 otherwise. It is also possible to use actual relevance scores.
    • NDCG is DCG divided by the maximum possible DCG score for that query. The higher, the better.
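
As a quick illustration of the formula above, here is a minimal NumPy sketch of DCG and NDCG for a single query; the variable names (`rels`, `p`) are just illustrative.

```python
import numpy as np

def dcg(rels, p=None):
    """DCG over the top-p ranked items: sum of rel_i / log2(i + 1)."""
    rels = np.asarray(rels, dtype=float)[:p]
    discounts = np.log2(np.arange(2, rels.size + 2))  # log2(i + 1) for i = 1..p
    return np.sum(rels / discounts)                   # or use (2**rels - 1) in the numerator

def ndcg(rels, p=None):
    """DCG normalized by the best achievable DCG (ideal ordering) for this query."""
    ideal = dcg(sorted(rels, reverse=True), p)
    return dcg(rels, p) / ideal if ideal > 0 else 0.0

# rels[i] = relevance of the i-th item in the model's ranking (binary or graded)
print(ndcg([1, 0, 1, 0]))  # ~0.92: one relevant item was ranked too low
```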

17.2. Triplet Loss

$$\mathcal{L}(A, P, N) = \max( \lVert f(A) - f(P) \rVert_2 - \lVert f(A) - f(N) \rVert_2 + \alpha,\ 0)$$

The distance from the anchor to the positive is minimized, and the distance from the anchor to the negative input is maximized.
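
A minimal NumPy sketch of the loss above; `f_a`, `f_p`, `f_n` stand for the embeddings $f(A)$, $f(P)$, $f(N)$ already produced by the model.

```python
import numpy as np

def triplet_loss(f_a, f_p, f_n, alpha=0.2):
    """max(||f(A) - f(P)||_2 - ||f(A) - f(N)||_2 + alpha, 0)"""
    d_pos = np.linalg.norm(f_a - f_p)  # anchor-positive distance (should shrink)
    d_neg = np.linalg.norm(f_a - f_n)  # anchor-negative distance (should grow)
    return max(d_pos - d_neg + alpha, 0.0)
```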

  • Data

    • Random Negative
      • The anchor and the positive are usually collected as a positive pair (either by human labeling or by implicit data collection).
      • A negative example is usually not explicitly collected, so random negative assignment is common.
        • Since random negatives are often too easy to distinguish, after a few iterations, the model learns nothing.
    • Online Negative Mining
      • To resolve this problem, online negative mining looks for hard negatives within the current batch, instead of using the one initially assigned.
      • Practically, a large batch size is required for good performance, since there are few hard negatives in a small batch. However, online negative mining takes $O(B^2)$ time for the k-NN search within the batch.
    • Semi-hard Negative Mining
      • Always using the hardest negative can be dangerous. A hard negative can push the model to collapse to $f(x) = 0$ to reduce $\mathcal{L}$, paying only $\alpha$.
      • So it is important to pick a negative whose current anchor-negative distance is just above the anchor-positive distance (a semi-hard negative; see the sketch after this list).
  • FaceNet

  • CDML (Collaborative Deep Metric Learning)

    • CDML freezes the feature extractor layers and trains the embedding network with the triplet loss.

    • Limitations
      CDML requires a large batch size due to online negative mining.
  • GCML (Graph Clustering Metric Learning)
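
The semi-hard negative mining described above can be sketched roughly as follows; `emb` (a batch of embeddings), `labels`, and the fallback to a random negative are assumptions for illustration, not details from the lecture.

```python
import numpy as np

def pick_semi_hard_negative(emb, labels, anchor, positive, alpha=0.2):
    """Return the index of a negative whose distance to the anchor is just
    above the anchor-positive distance (semi-hard), if one exists in the batch."""
    dists = np.linalg.norm(emb - emb[anchor], axis=1)    # anchor to every batch item
    d_pos = dists[positive]
    is_negative = labels != labels[anchor]               # items with a different identity
    semi_hard = is_negative & (dists > d_pos) & (dists < d_pos + alpha)
    candidates = np.where(semi_hard)[0]
    if candidates.size > 0:
        return candidates[np.argmin(dists[candidates])]  # closest semi-hard negative
    return np.random.choice(np.where(is_negative)[0])    # fallback: random negative
```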

17.3. Contrastive Learning

17.3.1. Pairwise Loss

$$\mathcal{L}(W) = \sum_{i=1}^P L(W, (Y, \vec{X_1}, \vec{X_2})^i)$$
$$L(W, (Y, \vec{X_1}, \vec{X_2})^i) = (1 - Y)\, L_S (D_W^i) + Y\, L_D (D_W^i)$$
  • For a pair of input examples, the ground-truth label $Y$ is either 0 (similar) or 1 (dissimilar).
  • For each case, a different loss function is applied, $L_S$ (similar) or $L_D$ (dissimilar).
  • $D_W$ is the distance between $\vec{X_1}$ and $\vec{X_2}$ computed by the model.
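
A small sketch of the pairwise loss, using one common choice for $L_S$ and $L_D$ (squared distance for similar pairs, a squared hinge with margin $m$ for dissimilar pairs); the margin value and these particular functions are assumptions, not fixed by the formula above.

```python
import numpy as np

def pairwise_contrastive_loss(x1, x2, y, m=1.0):
    """y = 0 for a similar pair, y = 1 for a dissimilar pair."""
    d_w = np.linalg.norm(x1 - x2)            # D_W: distance in the model's embedding space
    l_s = 0.5 * d_w ** 2                     # pull similar pairs together
    l_d = 0.5 * max(m - d_w, 0.0) ** 2       # push dissimilar pairs beyond the margin m
    return (1 - y) * l_s + y * l_d
```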

17.3.2. Negative Sampling

  • Softmax
    $$p(y = c_i | \text{x}) = \frac{e^{S_i}}{\sum_j e^{S_j}}$$
  • Cross-entropy
    $$\mathcal{L} = - \sum_{i=1}^n \text{y}_i \log p(y = c_i | \text{x})$$
  • Only one term (where $y_i = 1$) survives in the cross-entropy, but the softmax depends on all class scores due to the denominator.
  • This makes the loss depend on every output in the network, which means every network parameter will have a non-zero gradient. So the model needs to update for every training example.
  • Can we just sample some negatives, instead of computing all of them?
    • For most irrelevant labels, $p(y = c_i | \text{x}) \approx 0$.
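
To make the point concrete, the toy sketch below contrasts the full softmax cross-entropy, whose denominator touches every class score, with a crude sampled variant that only looks at the target class plus a few random negatives; the sampled formula here is purely illustrative, not the lecture's estimator.

```python
import numpy as np

rng = np.random.default_rng(0)

def full_softmax_ce(scores, target):
    """Cross-entropy with the full softmax: all C scores enter the denominator."""
    return -(scores[target] - np.log(np.sum(np.exp(scores))))

def sampled_ce(scores, target, num_neg=5):
    """Crude approximation: normalize over the target plus a few sampled negatives."""
    negatives = rng.choice(np.delete(np.arange(scores.size), target), num_neg, replace=False)
    subset = np.concatenate(([scores[target]], scores[negatives]))
    return -(scores[target] - np.log(np.sum(np.exp(subset))))

scores = rng.normal(size=100_000)   # e.g. a vocabulary of 100k classes
print(full_softmax_ce(scores, 3), sampled_ce(scores, 3))
```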

17.3.3. SimCLR (A Simple Framework for Contrastive Learning of Visual Representations)

  • Architecture

  • Loss

    $$\ell_{i,j} = - \log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \ne i]} \exp(\text{sim}(z_i, z_k) / \tau)}$$
  • Self-supervised learning: no label is required.
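
A hedged NumPy sketch of the loss for one positive pair $(i, j)$, assuming `z` is the $(2N, D)$ matrix of projected embeddings of the two augmented views and `sim` is cosine similarity.

```python
import numpy as np

def nt_xent_pair(z, i, j, tau=0.5):
    """SimCLR loss term l_{i,j} for the positive pair (i, j) within a batch of 2N views."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors: dot product = cosine sim
    sims = z @ z[i]                                    # sim(z_i, z_k) for every k
    exp_sims = np.exp(sims / tau)
    exp_sims[i] = 0.0                                  # indicator 1_[k != i] drops k = i
    return -np.log(np.exp(sims[j] / tau) / np.sum(exp_sims))
```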

17.3.4. Noise Contrastive Estimator (NCE)

  • Idea

    • $M$ examples $\{ \text{x}_1, ..., \text{x}_M \}$ from the observed probability distribution $p_m$
    • $N$ examples $\{ \text{y}_1, ..., \text{y}_N \}$ from a fake probability distribution $p_n$, generated artificially
    • Now, NCE trains the model to distinguish whether each sample came from $p_m$ or $p_n$.
      • Binary classification task with logistic regression
    • Summary
      • Sample some fake examples
      • Maximize the score of real examples
      • Minimize the score of fake ones
      • No need to compute scores for all classes
  • Loss

    $$\sum_{i=1}^M \ln [h(\text{x}_i; \theta)] + \sum_{j=1}^N \ln [1 - h(\text{y}_j; \theta)]$$
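
A minimal sketch of the objective above, assuming `h(x, theta)` is the model's (logistic) probability that an example is real; the objective is maximized over `theta`, e.g. by gradient ascent.

```python
import numpy as np

def nce_objective(h, theta, real_xs, fake_ys):
    """Sum of log h(x; theta) over real examples and log(1 - h(y; theta)) over fake ones."""
    real_term = sum(np.log(h(x, theta)) for x in real_xs)        # push scores of real examples up
    fake_term = sum(np.log(1.0 - h(y, theta)) for y in fake_ys)  # push scores of fake examples down
    return real_term + fake_term
```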
