Proba ML
16. Exemplar-Based Models
16.2 Learning Distance Metrics

16.2 Learning distance metrics

Being able to compute the semantic distance between a pair of points, $d(\bold{x},\bold{x}')\in\R^+$ for $\bold{x},\bold{x}'\in\mathcal{X}$, or equivalently their similarity $s(\bold{x},\bold{x}')\in\R^+$, is of crucial importance to tasks such as nearest neighbor classification, self-supervised learning, similarity-based clustering, content-based retrieval, etc.

When the input space is $\mathcal{X}=\R^D$, the most common metric is the Mahalanobis distance:

$$d_M(\bold{x},\bold{x}')=\sqrt{(\bold{x}-\bold{x}')^\top M(\bold{x}-\bold{x}')}$$

We discuss some ways to learn the matrix $M$ below. For high-dimensional or structured inputs, it is better to first learn an embedding $\bold{e}=f(\bold{x})$ and then compute distances in this embedding space.

When $f$ is a DNN, this is called deep metric learning.

16.2.1 Linear and convex methods

In this section, we discuss approaches to learning the Mahalanobis distance matrix $M$, either directly as a convex problem, or indirectly via a linear projection.

16.2.1.1 Large margin nearest neighbors (LMNN)

Large margin nearest neighbors learns $M$ so that the resulting distance metric works well when used by a nearest neighbor classifier.

For each example point $i$, let $N_i$ be its set of target neighbors, usually chosen as the $K$ points that share the same label and are closest in Euclidean distance.

We optimize $M$ so that the distance between each point $i$ and its target neighbors is minimized:

$$\mathcal{L}_{\mathrm{pull}}(M)=\sum_{i=1}^N \sum_{j \in N_i} d_M(\bold{x}_i,\bold{x}_j)^2$$

We also want examples with different labels to be far away.

To do so, we ensure that each point $i$ is closer (by a margin $m$) to its target neighbors $j$ than to other points $l$ with different labels, called impostors:

$$\mathcal{L}_{\mathrm{push}}(M)=\sum_{i=1}^N\sum_{j\in N_i}\sum_{l=1}^N \mathbb{I}(y_i\neq y_l)\big[m+d_M(\bold{x}_i,\bold{x}_j)^2-d_M(\bold{x}_i,\bold{x}_l)^2\big]_+$$

where $[z]_+=\max(z,0)$ is the hinge loss function.

The overall objective is:

$$\mathcal{L}(M)=(1-\lambda)\mathcal{L}_{\mathrm{pull}}(M)+\lambda\mathcal{L}_{\mathrm{push}}(M)$$

where $0<\lambda<1$. This is a convex function, defined over a convex set, which can be minimized using semidefinite programming.

Alternatively, we can parametrize the problem using $M=W^\top W$, and then minimize w.r.t. $W$ using unconstrained gradient methods. This is no longer convex, but it allows us to use a low-dimensional mapping $W$.

For large datasets, we need to tackle the $O(N^3)$ cost of computing $\mathcal{L}_{\mathrm{push}}$.
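As a concrete (if naive) illustration, here is a NumPy sketch of the LMNN objective for a fixed $M$; the target-neighbor sets are assumed to be precomputed, and the double loop is kept deliberately simple rather than efficient:

```python
import numpy as np

def lmnn_loss(M, X, y, target_neighbors, margin=1.0, lam=0.5):
    """Naive LMNN objective (1-lam)*L_pull + lam*L_push for a fixed PSD M.

    X: (N, D) inputs, y: (N,) labels.
    target_neighbors[i]: indices of the K same-class Euclidean neighbors of i.
    """
    def d2(a, b):                      # squared Mahalanobis distance
        diff = a - b
        return diff @ M @ diff

    N = X.shape[0]
    pull, push = 0.0, 0.0
    for i in range(N):
        for j in target_neighbors[i]:
            d_ij = d2(X[i], X[j])
            pull += d_ij                                   # pull target neighbors in
            for l in range(N):                             # push impostors out
                if y[l] != y[i]:
                    push += max(0.0, margin + d_ij - d2(X[i], X[l]))
    return (1 - lam) * pull + lam * push
```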

16.2.1.2 Neighborhood component analysis (NCA)

NCA is another way to learn a mapping $W$ such that $M=W^\top W$.

It defines the probability that sample $\bold{x}_i$ has $\bold{x}_j$ as its nearest neighbor, using the linear softmax function:

$$p_{ij}^W=\frac{\exp(-||W\bold{x}_i-W\bold{x}_j||^2_2)}{\sum_{l\neq i} \exp(-||W\bold{x}_i-W\bold{x}_l||^2_2)}$$

This is a supervised version of stochastic neighbor embedding.

The expected number of correctly classified examples for a 1NN classifier using the distance induced by $W$ is given by:

$$J(W)=\sum_{i=1}^N\sum_{j\neq i: y_i=y_j} p_{ij}^W$$

Let $\mathcal{L}(W)=1-J(W)/N$ be the leave-one-out error.

We can minimize $\mathcal{L}$ w.r.t. $W$ using gradient methods.
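A possible PyTorch sketch of this procedure, using squared distances computed by broadcasting and a diagonal mask to exclude $p_{ii}$ (the masking scheme is one reasonable implementation choice):

```python
import torch

def nca_loss(W, X, y):
    """NCA leave-one-out loss 1 - J(W)/N for a linear map W of shape (L, D)."""
    Z = X @ W.T                                                  # (N, L) projections
    sq_dists = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(-1)    # (N, N) squared distances
    N = X.shape[0]
    mask = torch.eye(N, dtype=torch.bool)
    sq_dists = sq_dists.masked_fill(mask, float('inf'))          # exclude j = i
    log_p = torch.log_softmax(-sq_dists, dim=1)                  # log p_ij^W
    same = (y[:, None] == y[None, :]) & ~mask                    # same-label pairs
    J = (log_p.exp() * same.float()).sum()                       # expected # correct 1-NN
    return 1.0 - J / N

# Hypothetical usage on random data:
X, y = torch.randn(100, 10), torch.randint(0, 3, (100,))
W = torch.randn(2, 10, requires_grad=True)
opt = torch.optim.Adam([W], lr=1e-2)
for _ in range(200):
    opt.zero_grad()
    loss = nca_loss(W, X, y)
    loss.backward()
    opt.step()
```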

16.2.1.3 Latent coincidence analysis (LCA)

LCA is another way to learn $W$ such that $M=W^\top W$, by defining a conditional latent variable model that maps a pair of inputs $(\bold{x},\bold{x}')$ to a label $y\in\{0,1\}$, which specifies whether the inputs are similar (have the same class label) or dissimilar.

Each input $\bold{x}\in\R^D$ is mapped to a low-dimensional latent point $\bold{z}\in\R^L$ using:

$$p(\bold{z}|\bold{x})=\mathcal{N}(\bold{z}|W\bold{x},\sigma^2I)$$

We then define the probability that the two inputs are similar with:

$$p(y=1|\bold{z},\bold{z}')=\exp\Big(-\frac{1}{2\kappa^2}||\bold{z}-\bold{z}'||^2\Big)$$


We can maximize the log marginal likelihood using the EM algorithm:

$$\ell(W,\sigma^2,\kappa^2)=\sum_{n=1}^N \log p(y_n|\bold{x}_n,\bold{x}_n')$$

In the E step, we compute the posterior $p(\bold{z},\bold{z}'|\bold{x},\bold{x}',y)$, which can be done in closed form.

In the M step, we solve a weighted least squares problem.

EM will monotonically increase the objective, and does not need step size adjustment, unlike gradient methods used in NCA.

It is also possible to fit this model using variational Bayes, as well as various sparse and nonlinear extensions.
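As an aside, because the latents are Gaussian, they can also be integrated out analytically, giving a closed-form expression for $p(y=1|\bold{x},\bold{x}')$. The sketch below evaluates the resulting log marginal likelihood directly, which could then be maximized by gradient ascent instead of EM; it is a derivation under the modeling assumptions above, not the algorithm from the original paper:

```python
import torch

def lca_log_marginal(W, log_sigma2, log_kappa2, x, x_prime, y):
    """log p(y | x, x') with the latents z, z' integrated out.

    Since u = z - z' ~ N(W(x - x'), 2*sigma2*I), a Gaussian integral gives
      p(y=1|x,x') = (kappa2 / (kappa2 + 2*sigma2))^(L/2)
                    * exp(-||W(x - x')||^2 / (2*(kappa2 + 2*sigma2))).
    W: (L, D); x, x_prime: (B, D); y: (B,) in {0, 1}.
    """
    sigma2, kappa2 = log_sigma2.exp(), log_kappa2.exp()
    L = W.shape[0]
    mu = (x - x_prime) @ W.T                          # mean of z - z', shape (B, L)
    denom = kappa2 + 2 * sigma2
    log_p1 = 0.5 * L * (log_kappa2 - torch.log(denom)) \
             - mu.pow(2).sum(-1) / (2 * denom)        # log p(y=1 | x, x')
    log_p0 = torch.log1p(-log_p1.exp().clamp(max=1 - 1e-6))
    return torch.where(y == 1, log_p1, log_p0)
```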

16.2.2 Deep metric learning (DML)

When measuring the distances of high-dimensional inputs, it is very useful to first learn an embedding to a lower dimensional “semantic” space, where the distances are more meaningful, and less subject to the curse of dimensionality.

Let $\bold{e}=f(\bold{x},\theta)\in\R^L$ be an embedding of the input, preserving its relevant semantic aspects.

The $\ell_2$-normalized version, $\hat{\bold{e}}=\bold{e}/||\bold{e}||_2$, ensures that all points lie on the unit hypersphere.

We can then measure the distance between two points using the normalized Euclidean distance (where smaller values mean more similar):

$$d(\bold{x}_i,\bold{x}_j,\theta)=||\hat{\bold{e}}_i-\hat{\bold{e}}_j||^2_2$$

or the cosine similarity (where larger values mean more similar):

$$d(\bold{x}_i,\bold{x}_j,\theta)=\hat{\bold{e}}_i^\top \hat{\bold{e}}_j$$

Both quantities are related by:

$$||\hat{\bold{e}}_i-\hat{\bold{e}}_j||^2_2=(\hat{\bold{e}}_i-\hat{\bold{e}}_j)^\top (\hat{\bold{e}}_i-\hat{\bold{e}}_j)=2-2\hat{\bold{e}}_i^\top \hat{\bold{e}}_j$$
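A quick NumPy check of this identity:

```python
import numpy as np

rng = np.random.default_rng(0)
e_i, e_j = rng.normal(size=5), rng.normal(size=5)
e_i, e_j = e_i / np.linalg.norm(e_i), e_j / np.linalg.norm(e_j)   # l2-normalize

sq_dist = np.sum((e_i - e_j) ** 2)
cosine = e_i @ e_j
assert np.isclose(sq_dist, 2 - 2 * cosine)    # ||e_i - e_j||^2 = 2 - 2 e_i^T e_j
```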

The overall approach is called deep metric learning.

The basic idea is to learn an embedding function such that similar examples are closer than dissimilar examples.

For example, if we have a labeled dataset, we can create a set of similar pairs $\mathcal{S}=\{(i,j):y_i=y_j\}$, and enforce that pairs $(i,j)\in\mathcal{S}$ are more similar than pairs $(i,k)\notin\mathcal{S}$.

Note that this method also works in unsupervised settings, provided we have some other way to define similar pairs.

Before discussing DML, it’s worth mentioning that some recent approaches have made invalid claims due to improper experimental comparisons, a common flaw in contemporary ML research. We will therefore focus on (slightly) older and simpler methods that tend to be more robust.

16.2.3 Classification losses

Suppose we have labeled data with $C$ classes. We can fit a classification model in $O(NC)$ time, and then reuse the hidden features as an embedding function. It is common to use the second-to-last layer, since it generalizes better to new classes than the last layer.

This approach is simple and scalable, but it only learns to embed examples on the correct side of the decision boundary. This doesn’t necessarily result in similar examples being placed close together and dissimilar examples being placed far apart.

In addition, this method can only be used with labeled training data.
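Despite these caveats, the recipe is simple. A minimal PyTorch sketch with a made-up classifier (the layer sizes and class count are placeholders):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# A hypothetical classifier: backbone -> penultimate features -> linear head.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
                         nn.Linear(256, 128), nn.ReLU())   # penultimate layer: 128-d
head = nn.Linear(128, 10)                                  # C = 10 classes
model = nn.Sequential(backbone, head)

# ... train `model` with cross-entropy on the labeled data ...

# At retrieval time, drop the head and use the penultimate features,
# optionally l2-normalized, as the embedding.
def embed(x):
    with torch.no_grad():
        e = backbone(x)
    return F.normalize(e, dim=-1)

emb = embed(torch.randn(4, 1, 28, 28))   # (4, 128) unit-norm embeddings
```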

16.2.4 Ranking losses

In this section, we minimize ranking losses, to ensure that similar examples are closer than dissimilar examples. Most of these methods don’t require labeled data.

16.2.4.1 Pairwise (contrastive) loss and Siamese networks

One of the earliest approaches to representation learning from similar/dissimilar pairs was based on minimizing the contrastive loss:

$$\mathcal{L}(\theta, \bold{x}_i,\bold{x}_j)=\mathbb{I}(y_i=y_j)d(\bold{x}_i,\bold{x}_j)^2+\mathbb{I}(y_i\neq y_j)[m-d(\bold{x}_i,\bold{x}_j)^2]_+$$

where $m>0$ is a margin parameter.

Intuitively, we want positive pairs to be close, and negative pairs to be further apart than some safety margin.

We minimize this loss over all pairs of data. Naively, this takes $O(N^2)$ time.

Note that we use the same feature extractor $f(\cdot,\theta)$ for both inputs $\bold{x}_i$ and $\bold{x}_j$ when computing the distance, hence the name Siamese network.
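A PyTorch sketch of the contrastive loss with a shared encoder `f`; `same_label` is assumed to be a 0/1 float tensor indicating which pairs share a label:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(f, x_i, x_j, same_label, margin=1.0):
    """Pairwise (contrastive) loss with a shared ('Siamese') encoder f."""
    e_i = F.normalize(f(x_i), dim=-1)             # same weights for both inputs
    e_j = F.normalize(f(x_j), dim=-1)
    d2 = (e_i - e_j).pow(2).sum(-1)               # squared distance per pair
    pos = same_label * d2                         # pull similar pairs together
    neg = (1 - same_label) * F.relu(margin - d2)  # push dissimilar pairs past the margin
    return (pos + neg).mean()
```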


16.2.4.2 Triplet loss

One drawback of pairwise losses is that the optimization of positive pairs is independent of the negative pairs, which can make their magnitudes incomparable.

The triplet loss addresses this. For each example $\bold{x}_i$ (known as the anchor), we find a similar example $\bold{x}_i^+$ and a dissimilar example $\bold{x}_i^-$, and define the loss as:

$$\mathcal{L}(\theta,\bold{x}_i,\bold{x}_i^+,\bold{x}_i^-)=[d_\theta(\bold{x}_i,\bold{x}_i^+)^2-d_\theta(\bold{x}_i,\bold{x}_i^-)^2 + m]_+$$

Intuitively, we want the positive example to be close to the anchor, and the negative example to be farther from the anchor by at least a safety margin $m$.

This loss can be computed using a triplet network.


Naively, minimizing this loss takes $O(N^3)$ time. In practice, we can compute it on a minibatch where the anchor point is the first entry, and there is at least one positive and one negative example.

However, this can still be slow.
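Setting the cost issue aside for a moment, the loss itself is easy to implement. A PyTorch sketch with a shared encoder `f` (note that PyTorch's built-in `nn.TripletMarginLoss` uses unsquared distances by default):

```python
import torch
import torch.nn.functional as F

def triplet_loss(f, anchor, positive, negative, margin=0.2):
    """Triplet loss [d(a, p)^2 - d(a, n)^2 + m]_+ with a shared encoder f."""
    e_a = F.normalize(f(anchor), dim=-1)
    e_p = F.normalize(f(positive), dim=-1)
    e_n = F.normalize(f(negative), dim=-1)
    d_ap = (e_a - e_p).pow(2).sum(-1)     # squared anchor-positive distance
    d_an = (e_a - e_n).pow(2).sum(-1)     # squared anchor-negative distance
    return F.relu(d_ap - d_an + margin).mean()
```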

16.2.4.3 N-pairs loss

The drawback of the triplet loss is that each anchor is only compared to one negative example, so the learning signal is not very strong.

A solution is to create a multi-class classification problem where, for every anchor, we use a set of $N-1$ negatives and one positive. This is called the N-pairs loss:

$$\begin{align} \mathcal{L}(\theta,\bold{x},\bold{x}^+,\{\bold{x}^-_k\}^{N-1}_{k=1}) &=\log\Big(1+\sum_{k=1}^{N-1}\exp(f^\top f^-_k-f^\top f^+)\Big)\\ &= -\log\frac{\exp(f^\top f^+)}{\exp(f^\top f^+)+ \sum_{k=1}^{N-1}\exp(f^\top f^-_k)} \end{align}$$

where $f=\hat{\bold{e}}_\theta(\bold{x})$, and similarly for $f^+$ and $f^-_k$.

This is the same as the InfoNCE loss used in the CPC paper.

When $N=2$, this reduces to the logistic loss:

$$\mathcal{L}(\theta,\bold{x},\bold{x}^+,\bold{x}^-)=\log(1+\exp(f^\top f^--f^\top f^+))$$

Compare this to the triplet loss when $m=1$:

$$\mathcal{L}(\theta,\bold{x},\bold{x}^+,\bold{x}^-)=[1+f^\top f^--f^\top f^+]_+$$
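The second form above is just a softmax cross-entropy in which the positive sits at a known index, which is how it is usually implemented. A PyTorch sketch, assuming the embeddings are already $\ell_2$-normalized:

```python
import torch
import torch.nn.functional as F

def n_pairs_loss(f_anchor, f_pos, f_negs):
    """N-pairs / InfoNCE-style loss.

    f_anchor: (B, L), f_pos: (B, L), f_negs: (B, N-1, L), all l2-normalized.
    """
    pos_logit = (f_anchor * f_pos).sum(-1, keepdim=True)        # (B, 1)
    neg_logits = torch.einsum('bl,bkl->bk', f_anchor, f_negs)   # (B, N-1)
    logits = torch.cat([pos_logit, neg_logits], dim=1)          # (B, N)
    labels = torch.zeros(logits.shape[0], dtype=torch.long)     # positive is at index 0
    return F.cross_entropy(logits, labels)
```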

16.2.5 Speeding up ranking loss optimization

The major issue with ranking losses is the $O(N^2)$ or $O(N^3)$ cost of computing the loss function, due to the need to compare pairs or triplets of examples.

We now review speedup tricks.

16.2.5.1 Mining techniques

A key insight is that most negative examples will result in zero loss, so we don’t need to take them all into account.

Instead, we can focus on negative examples that are closer to the anchor than the positive examples. These examples are called hard negatives.


If $a$ is an anchor and $p$ its nearest positive example, then $n$ is a hard negative if:

$$d(\bold{x}_a,\bold{x}_n)<d(\bold{x}_a,\bold{x}_p),\quad y_a\neq y_n$$

When the anchor doesn’t have any hard negatives, we can include semi-hard negatives, for which:

$$d(\bold{x}_a,\bold{x}_p)<d(\bold{x}_a,\bold{x}_n)<d(\bold{x}_a,\bold{x}_p)+m$$

This is the technique used by Google’s FaceNet model, which learns an embedding function for faces so that similar-looking faces are clustered together, to which the user can then attach a name.

In practice, the hard negatives are chosen from within the minibatch, which requires a large batch size for diversity.
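A deliberately naive sketch of in-batch semi-hard negative mining (cubic in the batch size, kept simple for clarity):

```python
import torch

def semi_hard_negatives(emb, labels, margin=0.2):
    """For each (anchor, positive) pair in a batch, pick a semi-hard negative:
    d(a, p) < d(a, n) < d(a, p) + margin. Returns a list of (a, p, n) index triples.
    """
    d = torch.cdist(emb, emb) ** 2        # squared pairwise distances within the batch
    triplets = []
    B = emb.shape[0]
    for a in range(B):
        for p in range(B):
            if p == a or labels[p] != labels[a]:
                continue                  # p must be a distinct positive
            for n in range(B):
                if labels[n] == labels[a]:
                    continue              # n must have a different label
                if d[a, p] < d[a, n] < d[a, p] + margin:
                    triplets.append((a, p, n))
                    break
    return triplets
```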

16.2.5.2 Proxy methods

Even with hard negative mining, triplet loss is expensive.

Instead, it has been suggested to define a set of $P$ proxies representing each class, and to compute the distances between each anchor and the proxies, instead of using all examples.

These proxies need to be updated online as the distance metric evolves during training. The overall procedure takes $O(NP^2)$ time, where $P\sim C$.

More recently, it has been proposed to take multiple prototypes for each class, while still achieving linear time complexity, using a soft triple loss.

16.2.5.3 Optimizing an upper bound

To optimize the triplet loss, it has been proposed to define a fixed proxy or centroid per class, and then use the distance to the proxy as an upper bound on the triplet loss.

Consider a triplet loss without the margin term:

$$\ell_t(\bold{x}_i,\bold{x}_j,\bold{x}_k)=||f_i-f_j||-||f_i-f_k||$$

Using the triangle inequality, we have:

$$\begin{align} ||f_i-f_j|| &\le ||f_i-c_{y_i}||+||f_j -c_{y_i}|| \\ ||f_i-f_k|| &\ge ||f_i-c_{y_k}||-||f_k-c_{y_k}|| \end{align}$$

Therefore:

$$\begin{align} \ell_t(\bold{x}_i,\bold{x}_j,\bold{x}_k)&\le \ell_u(\bold{x}_i,\bold{x}_j,\bold{x}_k)\\ &\triangleq ||f_i-c_{y_i}||-||f_i-c_{y_k}||+||f_j-c_{y_i}||+||f_k-c_{y_k}|| \end{align}$$

We can use this to derive a tractable upper bound on the triplet loss:

$$\begin{align} \mathcal{L}_t(\mathcal{D},\mathcal{S})&=\sum_{(i,j)\in\mathcal{S},(i,k)\notin\mathcal{S},(i,j,k)\in\{1,\dots,N\}}\ell_t(\bold{x}_i,\bold{x}_j,\bold{x}_k) \\ &\leq \sum_{(i,j)\in\mathcal{S},(i,k)\notin\mathcal{S},(i,j,k)\in\{1,\dots,N\}}\ell_u(\bold{x}_i,\bold{x}_j,\bold{x}_k) \\&= N'\sum_{i=1}^N \Big(||f_i-c_{y_i}||-\frac{1}{3(C-1)}\sum^C_{m=1,m\neq y_i}||f_i-c_m|| \Big) \\&=\mathcal{L}_u(\mathcal{D},\mathcal{S}) \end{align}$$

where $N'=3(C-1)\left(\frac{N}{C}-1\right)\frac{N}{C}$.

It is clear that $\mathcal{L}_u$ can be computed in $O(NC)$ time.
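A PyTorch sketch of $\mathcal{L}_u$ (up to the constant factor $N'$), assuming the class centroids are given:

```python
import torch

def upper_bound_loss(emb, labels, centroids):
    """O(NC) upper bound L_u, up to the constant N'.

    emb: (N, L) embeddings, labels: (N,) long, centroids: (C, L) fixed class centroids.
    """
    C = centroids.shape[0]
    d = torch.cdist(emb, centroids)                    # (N, C) distances to all centroids
    d_own = d.gather(1, labels[:, None]).squeeze(1)    # ||f_i - c_{y_i}||
    d_other = (d.sum(1) - d_own) / (3 * (C - 1))       # (1/(3(C-1))) * sum_{m != y_i} ||f_i - c_m||
    return (d_own - d_other).sum()
```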

It has been shown that

$$0\leq \mathcal{L}_u-\mathcal{L}_t\leq \frac{N^3}{C^2}K$$

where $K$ is some constant which depends on the number of centroids.

To ensure the bound is tight, the inter-cluster distances should be large and similar.

This can be enforced by defining the $\bold{c}_m$ vectors to be one-hot, one per class. These vectors are orthogonal to each other and have unit norm, so the distance between any two of them is $\sqrt{2}$.

The downside of this approach is that it assumes the embedding layer has dimension $L=C$. There are two solutions:

  1. After training, add a linear projection layer mapping $C$ to $L$, or use the second-to-last layer of the embedding network.
  2. Sample a large number of points on the $L$-dimensional unit sphere (by sampling from a standard normal distribution and then normalizing), and then run K-means clustering with $K=C$; see the sketch below.
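A sketch of the second option using scikit-learn's K-means; the number of sampled points is an arbitrary choice:

```python
import numpy as np
from sklearn.cluster import KMeans

def spread_out_centroids(C, L, n_samples=20_000, seed=0):
    """Option 2: C roughly equidistant unit vectors in R^L (for L != C)."""
    rng = np.random.default_rng(seed)
    pts = rng.normal(size=(n_samples, L))
    pts /= np.linalg.norm(pts, axis=1, keepdims=True)            # points on the unit sphere
    km = KMeans(n_clusters=C, n_init=10, random_state=seed).fit(pts)
    centroids = km.cluster_centers_
    return centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
```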

Interestingly, it has been shown that increasing the ratio $\pi_{\mathrm{intra}}/\pi_{\mathrm{inter}}$ results in higher downstream performance on various retrieval tasks, where:

$$\pi_\mathrm{intra}=\frac{1}{Z_{\mathrm{intra}}}\sum_{c=1}^C \sum_{i\neq j,y_i=y_j=c}d(\bold{x}_i,\bold{x}_j)$$

is the average intra-class distance and

$$\pi_{\mathrm{inter}}=\frac{1}{Z_\mathrm{inter}}\sum_{c=1}^C\sum_{c'=1}^C d(\mu_c,\mu_{c'})$$

is the average inter-class distance, where

$$\mu_c=\frac{1}{Z_c}\sum_{i:y_i=c}f_i$$

is the mean embedding for examples of class $c$.
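A PyTorch sketch of these two quantities in embedding space, taking the normalizers $Z_{\mathrm{intra}}$ and $Z_{\mathrm{inter}}$ to be the corresponding numbers of pairs (one reasonable convention, since they are left unspecified above):

```python
import torch

def intra_inter_ratio(emb, labels):
    """pi_intra / pi_inter: average same-class distance over average centroid distance.

    Assumes at least two classes and at least two examples in some class.
    """
    classes = labels.unique()
    mus = torch.stack([emb[labels == c].mean(0) for c in classes])   # class mean embeddings

    intra, n_intra = 0.0, 0
    for c in classes:
        e = emb[labels == c]
        intra += torch.cdist(e, e).sum()            # sums over ordered pairs (diagonal is 0)
        n_intra += e.shape[0] * (e.shape[0] - 1)
    pi_intra = intra / max(n_intra, 1)

    C = mus.shape[0]
    pi_inter = torch.cdist(mus, mus).sum() / (C * (C - 1))
    return pi_intra / pi_inter
```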

16.2.6 Other training tricks for DML

We present other important details for good DML performance.

i) One important factor is how the minibatch is created.

In classification tasks (at least with balanced classes), selecting examples at random from the training set is usually sufficient.

For DML, we need to ensure that each example has some other examples in the minibatch that are similar and dissimilar to it.

  • One approach is to use hard mining, as we saw previously.
  • Another is to use coreset methods applied to previously learned embeddings to select a diverse minibatch at each step.
  • The above-cited paper also shows that picking $B/N_c$ classes and sampling $N_c=2$ examples per class is a simple sampling method that works well for creating batches (see the sketch below).
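A NumPy sketch of this class-balanced sampling scheme; it assumes every class has at least `n_per_class` examples, and that there are enough classes to fill the batch:

```python
import numpy as np

def sample_balanced_batch(labels, batch_size, n_per_class=2, rng=None):
    """Pick batch_size / n_per_class classes, then n_per_class examples from each,
    so every example in the batch has at least one positive and many negatives."""
    if rng is None:
        rng = np.random.default_rng()
    labels = np.asarray(labels)
    classes = rng.choice(np.unique(labels), size=batch_size // n_per_class,
                         replace=False)
    idx = [rng.choice(np.where(labels == c)[0], size=n_per_class, replace=False)
           for c in classes]
    return np.concatenate(idx)
```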

ii) Another issue is avoiding overfitting. Since most datasets used in the DML literature are small, it is standard to use image classifiers like GoogLeNet or ResNet pre-trained on ImageNet, and then to fine-tune the model using a DML loss.

In addition, it is common to use data augmentation techniques (with self-supervised learning, this is the only way of creating similar pairs).

iii) It has also been proposed to add a spherical embedding constraint (SEC), which is a batchwise regularization term, which encourages all the examples to have the same norm.

The regularizer is the empirical variance of the norms of the unnormalized embeddings in that batch.


This regularizer can be added to any DML loss to modestly improve training speed and stability, as well as final performance, analogously to how batch norm is used.
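A PyTorch sketch of the SEC penalty; the weighting coefficient in the usage comment is hypothetical:

```python
import torch

def spherical_embedding_regularizer(emb_unnormalized):
    """SEC penalty: empirical variance of the embedding norms within the batch."""
    norms = emb_unnormalized.norm(dim=-1)
    return ((norms - norms.mean()) ** 2).mean()

# Hypothetical usage: total_loss = dml_loss + eta * spherical_embedding_regularizer(emb)
```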