16.2 Learning distance metrics
Being able to compute the semantic distance between a pair of points, $d(x, x') \in \mathbb{R}^+$ for $x, x' \in \mathcal{X}$, or equivalently their similarity $s(x, x') \in \mathbb{R}^+$, is of crucial importance to tasks such as nearest neighbor classification, self-supervised learning, similarity-based clustering, content-based retrieval, etc.
When the input space is $\mathcal{X} = \mathbb{R}^D$, the most common metric is the Mahalanobis distance:
$$d_M(x, x') = \sqrt{(x - x')^\top M (x - x')}$$
where $M$ is a positive semidefinite matrix.
We discuss some ways to learn the matrix $M$ below. For high-dimensional or structured inputs, it is better to first learn an embedding $e = f(x)$ and then to compute distances in this embedding space.
When $f$ is a DNN, this is called deep metric learning.
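As a concrete illustration, here is a minimal NumPy sketch (not from the text) of computing a Mahalanobis distance for a given PSD matrix $M$, together with the equivalent view as a Euclidean distance after a linear projection $W$ with $M = W^\top W$:

```python
import numpy as np

def mahalanobis_dist(x1, x2, M):
    """Mahalanobis distance sqrt((x1 - x2)^T M (x1 - x2)) for a PSD matrix M."""
    d = x1 - x2
    return np.sqrt(d @ M @ d)

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
M = A.T @ A                        # a PSD matrix
W = np.linalg.cholesky(M).T        # factor M = W^T W
x1, x2 = rng.standard_normal(3), rng.standard_normal(3)

# Equivalently, the Mahalanobis distance is the Euclidean distance after projecting by W.
assert np.isclose(mahalanobis_dist(x1, x2, M), np.linalg.norm(W @ x1 - W @ x2))
```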
16.2.1 Linear and convex methods
In this section, we discuss approaches to learning the Mahalanobis distance matrix $M$, either directly as a convex problem, or indirectly via a linear projection.
16.2.1.1 Large margin nearest neighbors (LMNN)
Large margin nearest neighbors (LMNN) learns $M$ so that the resulting distance metric works well when used by a nearest neighbor classifier.
For each example point $i$, let $N_i$ be the set of target neighbors, usually chosen as the set of $K$ points sharing the same label that are closest in Euclidean distance.
We optimize $M$ so that the distance between each point and its target neighbors is minimized:
$$\mathcal{L}_{\text{pull}}(M) = \sum_{i=1}^N \sum_{j \in N_i} d_M(x_i, x_j)^2$$
We also want to ensure that examples with incorrect labels are far away.
To do so, we ensure that each $x_i$ is closer (by a margin $m$) to its target neighbors than to points with different labels, called impostors:
$$\mathcal{L}_{\text{push}}(M) = \sum_{i=1}^N \sum_{j \in N_i} \sum_{l=1}^N \mathbb{I}(y_i \ne y_l)\,\big[m + d_M(x_i, x_j)^2 - d_M(x_i, x_l)^2\big]_+$$
where $[z]_+ = \max(z, 0)$ is the hinge loss function.
The overall objective is:
$$\mathcal{L}(M) = (1 - \lambda)\,\mathcal{L}_{\text{pull}}(M) + \lambda\,\mathcal{L}_{\text{push}}(M)$$
where $0 < \lambda < 1$. This is a convex function, defined over a convex set (the cone of PSD matrices), which can be minimized using semidefinite programming.
Alternatively, we can parametrize the problem using $M = W^\top W$, and then minimize w.r.t. $W$ using unconstrained gradient methods. This is no longer convex, but it allows us to use a low-dimensional mapping $W \in \mathbb{R}^{L \times D}$ with $L < D$.
For large datasets, we also need to tackle the cost of computing $\mathcal{L}_{\text{push}}$, which naively requires comparing every point to all candidate impostors.
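Below is a minimal NumPy sketch of the LMNN objective for a given linear map `W` (so that $M = W^\top W$); the number of target neighbors, the margin, and the trade-off weight `lam` are illustrative choices, and the loops are the naive evaluation:

```python
import numpy as np

def lmnn_loss(W, X, y, n_targets=3, margin=1.0, lam=0.5):
    """LMNN objective for M = W^T W, evaluated naively."""
    Dx = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)   # for picking target neighbors
    Z = X @ W.T                                            # z_i = W x_i
    Dz = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(-1)    # d_M(x_i, x_j)^2
    N = len(X)
    pull, push = 0.0, 0.0
    for i in range(N):
        same = np.where((y == y[i]) & (np.arange(N) != i))[0]
        targets = same[np.argsort(Dx[i, same])[:n_targets]]   # K nearest same-label points
        impostors = np.where(y != y[i])[0]
        for j in targets:
            pull += Dz[i, j]
            push += np.maximum(0.0, margin + Dz[i, j] - Dz[i, impostors]).sum()
    return (1 - lam) * pull + lam * push
```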
16.2.1.2 Neighborhood component analysis (NCA)
NCA is another way to learn a linear mapping $W$ such that $M = W^\top W$.
It defines the probability that sample $x_i$ has $x_j$ as its nearest neighbor using the linear softmax function:
$$p_{ij}^W = \frac{\exp(-\|W x_i - W x_j\|_2^2)}{\sum_{l \ne i} \exp(-\|W x_i - W x_l\|_2^2)}$$
This is a supervised version of stochastic neighbor embedding (SNE).
The expected number of correctly classified examples for a 1NN classifier based on this distance is given by:
$$J(W) = \sum_{i=1}^N \sum_{j \ne i:\, y_j = y_i} p_{ij}^W$$
Let $\mathcal{L}(W) = N - J(W)$ be the (expected) leave-one-out error.
We can minimize $\mathcal{L}(W)$ w.r.t. $W$ using gradient methods.
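A minimal NumPy sketch of the NCA objective $\mathcal{L}(W) = N - J(W)$ is shown below; in practice the gradient would be obtained via autodiff or the analytic expression from the NCA paper:

```python
import numpy as np

def nca_loss(W, X, y):
    """NCA leave-one-out objective L(W) = N - J(W)."""
    Z = X @ W.T                                          # z_i = W x_i
    D = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(-1)   # ||W x_i - W x_j||^2
    np.fill_diagonal(D, np.inf)                          # exclude l = i from the softmax
    P = np.exp(-D)
    P /= P.sum(axis=1, keepdims=True)                    # p_ij
    same = (y[:, None] == y[None, :]) & ~np.eye(len(y), dtype=bool)
    J = P[same].sum()                                    # expected number correct
    return len(X) - J
```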
16.2.1.3 Latent coincidence analysis (LCA)
LCA is another way to learn $W$ such that $M = W^\top W$, by defining a conditional latent variable model that maps a pair of inputs $(x, x')$ to a label $y \in \{0, 1\}$, which specifies whether the inputs are similar (e.g., have the same class label) or dissimilar.
Each input $x \in \mathbb{R}^D$ is mapped to a low-dimensional latent point $z \in \mathbb{R}^L$ using:
$$p(z \mid x) = \mathcal{N}(z \mid W x, \sigma^2 I)$$
We then define the probability that the two inputs are similar by:
$$p(y = 1 \mid z, z') = \exp\!\left(-\frac{1}{2\kappa^2}\,\|z - z'\|^2\right)$$
We can maximize the log marginal likelihood $\ell(W, \sigma^2, \kappa^2) = \sum_n \log p(y_n \mid x_n, x'_n)$ using the EM algorithm.
In the E step, we compute the posterior over the latent points, $p(z, z' \mid x, x', y)$, which can be done in closed form.
In the M step, we solve a weighted least squares problem.
EM will monotonically increase the objective, and does not need step size adjustment, unlike gradient methods used in NCA.
It is also possible to fit this model using variational Bayes, as well as various sparse and nonlinear extensions.
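As an aside, under the Gaussian model above the latent points can be integrated out analytically; the following sketch evaluates the resulting marginal similarity probability (this closed form follows from standard Gaussian integrals and is shown for illustration, not taken verbatim from the LCA paper):

```python
import numpy as np

def lca_similarity_prob(x1, x2, W, sigma2, kappa2):
    """p(y=1 | x, x') with the latents z, z' integrated out of the Gaussian model."""
    L = W.shape[0]
    mu = W @ (x1 - x2)               # mean of z - z' ~ N(W(x - x'), 2*sigma2*I)
    s = kappa2 + 2.0 * sigma2
    return (kappa2 / s) ** (L / 2) * np.exp(-(mu @ mu) / (2.0 * s))
```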
16.2.2 Deep metric learning (DML)
When measuring the distances of high-dimensional inputs, it is very useful to first learn an embedding to a lower dimensional “semantic” space, where the distances are more meaningful, and less subject to the curse of dimensionality.
Let $e = f(x; \theta) \in \mathbb{R}^L$ be an embedding of the input that preserves its relevant semantic aspects.
The $\ell_2$-normalized version, $\hat{e} = e / \|e\|_2$, ensures that all points lie on the unit hypersphere.
We can then measure the distance between two points using the normalized Euclidean distance (where smaller values mean more similar):
$$d(x_i, x_j) = \|\hat{e}_i - \hat{e}_j\|_2^2$$
or the cosine similarity (where larger values mean more similar):
$$\text{sim}(x_i, x_j) = \hat{e}_i^\top \hat{e}_j$$
Both quantities are related by:
$$d(x_i, x_j) = 2\,\big(1 - \text{sim}(x_i, x_j)\big)$$
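A quick NumPy check of this relation on unit-normalized embeddings (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
e1, e2 = rng.standard_normal(8), rng.standard_normal(8)
e1, e2 = e1 / np.linalg.norm(e1), e2 / np.linalg.norm(e2)   # l2-normalize

dist = np.sum((e1 - e2) ** 2)      # squared Euclidean distance on the sphere
sim = e1 @ e2                      # cosine similarity
assert np.isclose(dist, 2 * (1 - sim))
```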
The overall approach is called deep metric learning.
The basic idea is to learn an embedding function such that similar examples are closer than dissimilar examples.
For example, if we have a labeled dataset, we can create a set of similar pairs, $S = \{(i, j) : y_i = y_j\}$, and enforce that pairs in $S$ be more similar to each other than pairs outside $S$.
Note that this method also works in unsupervised settings, provided we have some other way of defining similar pairs.
Before discussing DML, it’s worth mentioning that some recent approaches made invalid claims due to improper experimental comparisons, a common flaw in contemporary ML research. We will therefore focus on (slightly) older and simpler methods, that tend to be more robust.
16.2.3 Classification losses
Suppose we have labeled data with $C$ classes. We can fit a classification model in $O(NC)$ time, and then reuse the hidden features as an embedding function; it is common to use the second-to-last layer, since it generalizes better to new classes than the last layer.
This approach is simple and scalable, but it only learns to embed examples on the correct side of the decision boundary. This doesn't necessarily result in similar examples being placed close together and dissimilar examples being placed far apart.
In addition, this method can only be used with labeled training data.
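As a toy illustration (a hypothetical two-layer network, not a specific architecture from the text), the embedding is simply the activation of the layer just below the classification head:

```python
import numpy as np

def classifier_embed(x, W1, b1, W2, b2):
    """Toy two-layer classifier; its hidden layer doubles as the embedding."""
    h = np.maximum(0.0, x @ W1 + b1)   # second-to-last layer -> embedding e = f(x)
    logits = h @ W2 + b2               # last (class-specific) layer, used only for training
    return h, logits
```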
16.2.4 Ranking losses
In this section, we minimize ranking losses to ensure that similar examples are closer than dissimilar examples. Most of these methods don't require labeled data.
16.2.4.1 Pairwise (contrastive) loss and Siamese networks
One of the earliest approaches to representation learning from similar/dissimilar pairs was based on minimizing the contrastive loss:
$$\mathcal{L}(\theta; x_i, x_j) = \mathbb{I}(y_i = y_j)\, d(x_i, x_j)^2 + \mathbb{I}(y_i \ne y_j)\,\big[m - d(x_i, x_j)\big]_+^2$$
where $d(x_i, x_j) = \|f(x_i; \theta) - f(x_j; \theta)\|_2$ and $m > 0$ is a margin parameter.
Intuitively, we want positive pairs to be close, and negative pairs to be further apart than some safety margin.
We minimize this loss over all pairs of data. Naively, this takes $O(N^2)$ time.
Note that we use the same feature extractor $f(\cdot; \theta)$ for both inputs when computing the distance, hence the name Siamese network.
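A minimal sketch of the contrastive loss for a batch of embedding pairs (the encoder itself is omitted, and the margin value is an illustrative choice):

```python
import numpy as np

def contrastive_loss(e1, e2, same, margin=1.0):
    """Pairwise contrastive loss.

    e1, e2 : (B, L) embeddings of the two sides of each pair (same encoder for both).
    same   : (B,) boolean, True if the pair is similar (e.g. same label).
    """
    d = np.linalg.norm(e1 - e2, axis=1)            # Euclidean distances
    pos = d ** 2                                   # pull similar pairs together
    neg = np.maximum(0.0, margin - d) ** 2         # push dissimilar pairs beyond the margin
    return np.where(same, pos, neg).mean()
```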
16.2.4.2 Triplet loss
One drawback of pairwise losses is that the optimization of positive pairs is independent of that of negative pairs, which can make their magnitudes incomparable.
The triplet loss addresses this by finding, for each example $x_i$ (known as the anchor), a similar example $x_i^+$ and a dissimilar example $x_i^-$, and minimizing:
$$\mathcal{L}(\theta; x_i, x_i^+, x_i^-) = \big[d_\theta(x_i, x_i^+)^2 - d_\theta(x_i, x_i^-)^2 + m\big]_+$$
Intuitively, we want the positive example to be close to the anchor, and the negative example to be further from the anchor by at least a safety margin $m$.
This loss can be computed using a triplet network, which applies the same embedding network to the anchor, the positive, and the negative example.
Naively minimizing this loss takes $O(N^3)$ time. In practice, we can compute it on a minibatch where the anchor point is the first entry, and there is at least one positive and one negative example.
However, this can still be slow.
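A minimal sketch of the triplet loss on a batch of (anchor, positive, negative) embeddings (the margin value is illustrative):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Triplet loss on (B, L) batches of anchor/positive/negative embeddings."""
    d_pos = np.sum((anchor - positive) ** 2, axis=1)   # squared anchor-positive distances
    d_neg = np.sum((anchor - negative) ** 2, axis=1)   # squared anchor-negative distances
    return np.maximum(0.0, d_pos - d_neg + margin).mean()
```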
16.2.4.3 N-pairs loss
The drawback of the triplet loss is that each anchor is only compared to one negative example, so the learning signal is not very strong.
A solution is to create a multi-class classification problem in which we create a set of $N-1$ negatives and one positive for every anchor. This is called the N-pairs loss:
$$\mathcal{L}\big(\theta; x, x^+, \{x_k^-\}_{k=1}^{N-1}\big) = -\log \frac{\exp(\hat{e}^\top \hat{e}^+)}{\exp(\hat{e}^\top \hat{e}^+) + \sum_{k=1}^{N-1} \exp(\hat{e}^\top \hat{e}_k^-)}$$
where $\hat{e}$, $\hat{e}^+$, and $\hat{e}_k^-$ denote the $\ell_2$-normalized embeddings of $x$, $x^+$, and $x_k^-$.
This is the same as the InfoNCE loss used in the CPC paper.
When $N = 2$, this reduces to the logistic loss:
$$\mathcal{L}(\theta; x, x^+, x^-) = \log\big(1 + \exp(\hat{e}^\top \hat{e}^- - \hat{e}^\top \hat{e}^+)\big)$$
Compare this to the triplet loss, which for unit-norm embeddings (and up to constant factors) can be written as:
$$\mathcal{L}(\theta; x, x^+, x^-) = \big[\hat{e}^\top \hat{e}^- - \hat{e}^\top \hat{e}^+ + m\big]_+$$
The logistic loss is a smooth version of this hinge.
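A minimal sketch of the N-pairs / InfoNCE-style loss for a single anchor (assuming the embeddings have already been $\ell_2$-normalized):

```python
import numpy as np

def n_pairs_loss(anchor, positive, negatives):
    """N-pairs / InfoNCE-style loss for one anchor.

    anchor, positive : (L,) l2-normalized embeddings.
    negatives        : (K, L) l2-normalized negative embeddings.
    """
    logits = np.concatenate(([anchor @ positive], negatives @ anchor))
    # cross-entropy with the positive in position 0 (log-sum-exp for numerical stability)
    return np.logaddexp.reduce(logits) - logits[0]
```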
16.2.5 Speeding up ranking loss optimization
The major issue with ranking losses is the $O(N^2)$ or $O(N^3)$ cost of computing the loss function, due to the need to compare pairs or triplets of examples.
We now review speedup tricks.
16.2.5.1 Mining techniques
A key insight is that most negative examples will result in zero loss, so we don't need to take them all into account.
Instead, we can focus on negative examples that are closer to the anchor than the positive examples are. These examples are called hard negatives.
If $a$ is an anchor and $p$ is its nearest positive example, then $n$ is a hard negative if
$$d(a, n) < d(a, p) \quad \text{and} \quad y_n \ne y_a$$
When the anchor doesn't have any hard negatives, we can include semi-hard negatives, for which
$$d(a, p) < d(a, n) < d(a, p) + m$$
This is the technique used by Google's FaceNet model, which learns an embedding function for faces so that it can cluster similar-looking faces together, to which the user can then attach a name.
In practice, the hard negatives are chosen from within the minibatch, which requires a large batch size to ensure sufficient diversity.
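A minimal sketch of in-batch semi-hard negative mining (the margin, and the choice of returning the hardest of the semi-hard negatives, are illustrative):

```python
import numpy as np

def semi_hard_negatives(emb, labels, margin=1.0):
    """For each anchor, pick a semi-hard negative from the minibatch (if any).

    emb    : (B, L) embeddings; labels : (B,) integer class labels.
    Returns an array of negative indices (or -1 when none exists).
    """
    D = np.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1)  # pairwise distances
    out = np.full(len(emb), -1)
    for a in range(len(emb)):
        pos = np.where((labels == labels[a]) & (np.arange(len(emb)) != a))[0]
        if len(pos) == 0:
            continue
        d_ap = D[a, pos].min()                       # distance to nearest positive
        neg = np.where(labels != labels[a])[0]
        semi = neg[(D[a, neg] > d_ap) & (D[a, neg] < d_ap + margin)]
        if len(semi) > 0:
            out[a] = semi[np.argmin(D[a, semi])]     # hardest of the semi-hard negatives
    return out
```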
16.2.5.2 Proxy methods
Even with hard negative mining, triplet loss is expensive.
Instead, it has been suggested to define a set of proxies, one representing each class, and to compute the distances between each anchor and the proxies, instead of comparing to all other examples.
These proxies need to be updated online as the distance metric evolves during training. The resulting cost scales with the number of proxies $P$ rather than the number of examples, where $P \ll N$.
More recently, it has been proposed to use multiple prototypes per class, while still achieving linear time complexity, using the SoftTriple loss.
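A minimal sketch of a proxy-based loss in the spirit of these methods: each anchor is compared only to one learnable proxy per class, rather than to other examples (the softmax form over negative squared distances is an illustrative choice):

```python
import numpy as np

def proxy_loss(emb, labels, proxies):
    """Proxy-based loss: pull each anchor toward its class proxy, away from the others.

    emb     : (B, L) l2-normalized embeddings.
    labels  : (B,) integer class labels.
    proxies : (C, L) l2-normalized proxy vectors, one per class (learned jointly).
    """
    D = ((emb[:, None, :] - proxies[None, :, :]) ** 2).sum(-1)  # (B, C) squared distances
    logits = -D
    logZ = np.logaddexp.reduce(logits, axis=1)                  # log-sum-exp over classes
    return (logZ - logits[np.arange(len(emb)), labels]).mean()
```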
16.2.5.3 Optimizing an upper bound
To optimize the triplet loss more cheaply, it has been proposed to define a fixed proxy or centroid per class, and then to use the distance to these centroids as an upper bound on the triplet loss.
Consider a triplet loss without the margin term:
$$\ell_t(x_i, x_j, x_k) = \|f_i - f_j\| - \|f_i - f_k\|$$
where $f_i = f(x_i)$, $x_j$ is a positive example, and $x_k$ is a negative example. Using the triangle inequality with the class centroids $c_1, \dots, c_C$, we have:
$$\|f_i - f_j\| \le \|f_i - c_{y_i}\| + \|f_j - c_{y_i}\|, \qquad \|f_i - f_k\| \ge \|f_i - c_{y_k}\| - \|f_k - c_{y_k}\|$$
Therefore:
$$\ell_t(x_i, x_j, x_k) \le \|f_i - c_{y_i}\| - \|f_i - c_{y_k}\| + \|f_j - c_{y_i}\| + \|f_k - c_{y_k}\|$$
We can use this to derive a tractable upper bound on the total triplet loss:
$$\mathcal{L}_t(\mathcal{D}) \le \mathcal{L}_u(\mathcal{D}) \triangleq \sum_{i=1}^N \left( \|f_i - c_{y_i}\| - \frac{1}{3(C-1)} \sum_{m=1,\, m \ne y_i}^{C} \|f_i - c_m\| \right)$$
where $c_m$ is the centroid for class $m$.
It is clear that $\mathcal{L}_u$ can be computed in $O(NC)$ time.
It has been shown that
$$0 \le \mathcal{L}_u(\mathcal{D}) - \mathcal{L}_t(\mathcal{D}) \le K$$
where $K$ is some constant which depends on the number of centroids.
To ensure the bound is tight, the inter-cluster distances should be large and similar.
This can be enforced by defining the centroid vectors to be one-hot, one per class. These vectors are mutually orthogonal and have unit norm, so the distance between any two of them is $\sqrt{2}$.
The downside of this approach is that it assumes the embedding layer is $C$-dimensional. Two solutions:
- After training, add a linear projection layer mapping $\mathbb{R}^C$ to $\mathbb{R}^L$, or use the second-to-last layer of the embedding network.
- Sample a large number of points on the $L$-dimensional unit sphere (by sampling from a standard Gaussian and then normalizing) and then run K-means clustering with $K = C$ clusters to get the centroids.
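A minimal sketch of the centroid-based upper bound $\mathcal{L}_u$ with fixed one-hot centroids, assuming (as discussed above) that the embedding dimension equals the number of classes:

```python
import numpy as np

def centroid_upper_bound_loss(emb, labels, n_classes):
    """Upper bound L_u on the triplet loss, using fixed one-hot class centroids.

    emb : (N, C) embeddings (dimension must equal n_classes); labels : (N,) class ids.
    """
    C = n_classes
    centroids = np.eye(C)                                    # one one-hot centroid per class
    D = np.linalg.norm(emb[:, None, :] - centroids[None, :, :], axis=-1)  # (N, C)
    d_own = D[np.arange(len(emb)), labels]                   # distance to own centroid
    d_other = (D.sum(axis=1) - d_own) / (3.0 * (C - 1))      # weighted distance to the rest
    return (d_own - d_other).sum()
```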
Interestingly, it has been shown that increasing the ratio $\pi_{\text{ratio}} = \pi_{\text{intra}} / \pi_{\text{inter}}$ results in higher downstream performance on various retrieval tasks, where
$$\pi_{\text{intra}} = \frac{1}{Z_{\text{intra}}} \sum_{c=1}^{C} \sum_{i \ne j:\, y_i = y_j = c} d(f_i, f_j)$$
is the average intra-class distance,
$$\pi_{\text{inter}} = \frac{1}{Z_{\text{inter}}} \sum_{c \ne c'} d(\mu_c, \mu_{c'})$$
is the average inter-class distance, and
$$\mu_c = \frac{1}{Z_c} \sum_{i:\, y_i = c} f_i$$
is the mean embedding for examples of class $c$ ($Z_{\text{intra}}$, $Z_{\text{inter}}$, and $Z_c$ are normalization constants).
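A small NumPy sketch of these two quantities, using Euclidean distance on the embeddings and averaging over the summed pairs (which plays the role of the normalization constants):

```python
import numpy as np

def intra_inter_distances(emb, labels):
    """Average intra-class distance and average inter-class (centroid) distance."""
    classes = np.unique(labels)
    mu = np.stack([emb[labels == c].mean(axis=0) for c in classes])   # class means
    intra = []
    for c in classes:
        E = emb[labels == c]
        D = np.linalg.norm(E[:, None, :] - E[None, :, :], axis=-1)
        intra.extend(D[np.triu_indices(len(E), k=1)])                 # within-class pairs
    Dmu = np.linalg.norm(mu[:, None, :] - mu[None, :, :], axis=-1)
    pi_intra = np.mean(intra)
    pi_inter = np.mean(Dmu[np.triu_indices(len(classes), k=1)])
    return pi_intra, pi_inter
```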
16.2.6 Other training tricks for DML
We present other important details for good DML performance.
i) One important factor is how the minibatch is created.
In classification tasks (at least with balanced classes), selecting examples at random from the training set is usually sufficient.
For DML, we need to ensure that each example has some other examples in the minibatch that are similar and dissimilar to it.
- One approach is to use hard mining, as discussed previously.
- Another is to use coreset methods applied to previously learned embeddings to select a diverse minibatch at each step.
- The paper cited above also shows that picking a random subset of classes and then sampling several examples per class is a simple sampling method that works well for creating batches (see the sketch below).
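A minimal sketch of this class-balanced batch sampling scheme (the number of classes per batch and of examples per class are illustrative hyperparameters):

```python
import numpy as np

def sample_batch(labels, n_classes_per_batch=8, n_per_class=4, rng=None):
    """Return indices for a minibatch containing several examples from each sampled class."""
    if rng is None:
        rng = np.random.default_rng()
    uniq = np.unique(labels)
    classes = rng.choice(uniq, size=min(n_classes_per_batch, len(uniq)), replace=False)
    idx = []
    for c in classes:
        members = np.where(labels == c)[0]
        idx.extend(rng.choice(members, size=n_per_class,
                              replace=len(members) < n_per_class))
    return np.array(idx)
```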
ii) Another issue is avoiding overfitting. Since most datasets used in the DML literature are small, it is standard to use image classifiers such as GoogLeNet or ResNet pre-trained on ImageNet, and then to fine-tune the model using a DML loss.
In addition, it is common to use data augmentation techniques (indeed, with self-supervised learning, augmentation is the only way of creating similar pairs).
iii) It has also been proposed to add a spherical embedding constraint (SEC), a batchwise regularization term which encourages all the examples to have the same norm.
The regularizer is the empirical variance of the norms of the unnormalized embeddings in that batch.
This regularizer can be added to any DML loss to modestly improve training speed and stability, as well as final performance, analogously to how batchnorm is used.
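A minimal sketch of the SEC regularizer as described; the weight with which it is added to the main DML loss is an illustrative hyperparameter:

```python
import numpy as np

def sec_regularizer(emb):
    """Spherical embedding constraint: variance of the embedding norms in the batch."""
    norms = np.linalg.norm(emb, axis=1)   # norms of the *unnormalized* embeddings
    return np.var(norms)

# Usage (illustrative): total_loss = dml_loss + sec_weight * sec_regularizer(emb)
```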