19.5 Meta-learning

We can think of a learning algorithm as a function $A$ that maps data to a parameter estimate, $\theta = A(\mathcal{D})$.

The function $A$ usually has its own parameters $\phi$, such as the initial values for $\theta$ or the learning rate. We then have $\theta = A(\mathcal{D}; \phi)$.

We can imagine learning $\phi$ itself, given a collection of datasets $\mathcal{D}_{1:J}$ and some meta-learning algorithm $M$, i.e. $\phi = M(\mathcal{D}_{1:J})$.

We can then apply $A(\cdot; \phi)$ to a new dataset $\mathcal{D}_{J+1}$ to learn its parameters $\theta_{J+1}$. This is also called learning to learn.
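
To make this abstraction concrete, here is a minimal sketch on toy 1-d linear regression tasks. All names (`A`, `M`, `make_task`) are illustrative, and the simplistic `M` here (averaging per-task solutions) is only a stand-in for the gradient-based meta-learner discussed below.

```python
import numpy as np

def A(D, phi, num_steps=10, lr=0.1):
    """Base learner A(D; phi): a few gradient steps on squared error,
    starting from the meta-parameters phi (here, an initialization)."""
    X, y = D
    theta = phi.copy()
    for _ in range(num_steps):
        grad = 2 * X.T @ (X @ theta - y) / len(y)
        theta = theta - lr * grad
    return theta

def M(datasets):
    """A deliberately simple meta-learner: set phi to the average of the
    per-task solutions fit from scratch."""
    return np.mean([A(D, np.zeros(2)) for D in datasets], axis=0)

rng = np.random.default_rng(0)

def make_task(n=20):
    """Toy tasks whose true weights are drawn near a shared prior mean."""
    X = np.column_stack([np.ones(n), rng.normal(size=n)])
    w_j = rng.normal(loc=[1.0, 2.0], scale=0.1)
    return X, X @ w_j + 0.05 * rng.normal(size=n)

phi = M([make_task() for _ in range(5)])   # phi = M(D_{1:J})
theta_new = A(make_task(), phi)            # theta_{J+1} = A(D_{J+1}; phi)
```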

19.5.1 Model-agnostic meta-learning (MAML)

A natural approach to meta-learning is to use a hierarchical Bayesian model.

We can assume that the parameters $\theta_j$ come from a common prior $p(\theta_j \mid \phi)$, which can be used to pool statistical strength across multiple data-poor problems.

Meta-learning becomes equivalent to learning the prior $\phi$. Rather than performing full Bayesian inference, we use the following empirical Bayes approximation:

$$\phi^* = \argmax_\phi \frac{1}{J} \sum_{j=1}^J \log p(\mathcal{D}^j_{\mathrm{valid}} \mid \hat{\theta}_j(\mathcal{D}^j_{\mathrm{train}}, \phi))$$

where $\hat{\theta}_j = \hat{\theta}(\mathcal{D}^j_{\mathrm{train}}, \phi)$ is a point estimate of the parameters for task $j$, and we use a cross-validation approximation to the marginal likelihood.
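
In code, this objective averages the held-out log-likelihoods of the per-task point estimates. A minimal sketch, reusing the toy `A` base learner above and assuming a unit-variance Gaussian likelihood (so the log-likelihood is a negative squared error up to additive constants):

```python
import numpy as np

def meta_objective(phi, tasks):
    """(1/J) sum_j log p(D_valid^j | theta_hat_j(D_train^j, phi))."""
    total = 0.0
    for D_train, D_valid in tasks:
        theta_j = A(D_train, phi)   # point estimate for task j, adapted from phi
        X_va, y_va = D_valid
        # log N(y | X theta, I), dropping additive constants
        total += -0.5 * np.sum((X_va @ theta_j - y_va) ** 2)
    return total / len(tasks)
```

Meta-training then amounts to maximizing `meta_objective` over $\phi$.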

To compute the point estimate of the parameters for the target task, $\hat{\theta}_{J+1}$, we use $K$ steps of a gradient ascent procedure starting from $\phi$, with learning rate $\eta$. This is known as model-agnostic meta-learning (MAML).
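
A minimal sketch of this adaptation step, in the same toy linear-Gaussian setup (the `adapt` name and the unit-variance likelihood are assumptions for illustration; a full MAML implementation would also differentiate through these $K$ steps to update $\phi$, which is omitted here):

```python
def adapt(phi, D_train, K=5, eta=0.1):
    """K steps of gradient ascent on the task log-likelihood, starting from phi."""
    X, y = D_train
    theta = phi.copy()
    for _ in range(K):
        # gradient of log N(y | X theta, I) = -0.5 ||X theta - y||^2 + const,
        # averaged over the examples in the task's training set
        grad_loglik = -X.T @ (X @ theta - y) / len(y)
        theta = theta + eta * grad_loglik
    return theta

# Fast adaptation to the target task: theta_hat_{J+1} = adapt(phi, D_train_{J+1})
```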

This can be shown to be equivalent to an approximate MAP estimate using a Gaussian prior centered at $\phi$, where the strength of the prior is controlled by the number of gradient steps (this is an example of fast adaptation of the task-specific weights from the shared prior $\phi$).
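
One heuristic way to read this equivalence: stopping after $K$ small steps acts like an implicit Gaussian penalty tying $\theta$ to $\phi$. The display below is a sketch of that reading, not a derivation; the implied prior variance $\sigma^2(K, \eta)$ is an assumed shorthand for a quantity that grows with the number of steps.

$$\hat{\theta}_{J+1} \approx \argmax_\theta \Big[ \log p(\mathcal{D}^{J+1}_{\mathrm{train}} \mid \theta) - \frac{1}{2\sigma^2(K, \eta)} \|\theta - \phi\|_2^2 \Big]$$

Fewer steps (smaller $K$) correspond to a tighter prior, pulling $\hat{\theta}_{J+1}$ toward the shared initialization $\phi$.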