1. Introduction
Proposes a general, model-agnostic meta-learning algorithm.
- A task Ti is drawn from task distribution p(T).
- For Ti, the model is trained on only K samples with feedback from the loss LTi.
- It is then tested on new samples from Ti, and this test error is used when improving the model f with respect to its parameters. (Generally expressed as 'model f', since the meta-learnable object differs across methods, from algorithms to parameters.) Thus the per-task test error serves as the training error of the meta-learning process.
- Finally, new tasks are sampled from p(T), and meta-performance is measured after learning from K samples. (Tasks used for meta-testing are held out during meta-training.)
Intuition
- Some representations are more transferable.
-> Let neural networks learn features from such representations, so they are broadly applicable to all tasks in p(T).
Problems
- Do such representations exist? (Personal question)
- How can we tell which representation is the transferable one?
Approach
- It is a relative matter: we are not looking for an absolutely transferable representation, but a relatively more transferable one.
- Find the model that makes the most rapid adjustment to new tasks from p(T), i.e., one where a small change in parameters evokes a large improvement. Such a model is more transferable.
Setup
- Ti ∼ p(T) : a batch of tasks drawn from the task distribution
- Ti = {τ1, ⋯, τi} : the sampled batch of i tasks
- τm = {(xmj, ymj)}_{j=1}^{J} : the data of a single task
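The setup above can be sketched with a toy task distribution. Here p(T) is sine regression with random amplitude and phase (as in the paper's regression experiments); the exact ranges and sizes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_task(J=10):
    """Sample one task tau_m from a toy p(T): sine regression with a
    random amplitude and phase (illustrative ranges)."""
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, np.pi)
    x = rng.uniform(-5.0, 5.0, size=(J, 1))
    y = amplitude * np.sin(x + phase)
    return x, y  # tau_m = {(x_mj, y_mj)}_{j=1..J}

def sample_task_batch(M=4, J=10):
    """A batch Ti = {tau_1, ..., tau_M} of M tasks drawn from p(T)."""
    return [sample_task(J) for _ in range(M)]

batch = sample_task_batch(M=4, J=10)
print(len(batch), batch[0][0].shape)  # 4 tasks, each with J=10 points
```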
Algorithm
Inner loop
- For each task in Ti = {τ1, ⋯, τi}, sample K data points Ds (the support set) each and take one gradient-descent step:
- θi′ = θ − α∇θLTi(fθ)
- This yields i sets of adapted inner-loop parameters.
- Also sample Dq (the query set) for the meta-update (outer loop).
- For each task τm, we now have (θm′, Dmq), where m ∈ {1, ⋯, i}.
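The inner-loop step above can be sketched with a toy linear model and MSE loss (the model, data, and α value are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.01  # inner-loop step size (hypothetical value)

def loss_and_grad(theta, x, y):
    """MSE loss L(f_theta) and its gradient for a linear model f_theta(x) = x @ theta."""
    err = x @ theta - y
    loss = float(np.mean(err ** 2))
    grad = 2.0 * x.T @ err / len(x)
    return loss, grad

def inner_update(theta, support_x, support_y):
    """theta_i' = theta - alpha * grad L_Ti(f_theta), computed on the K support points D^s."""
    _, grad = loss_and_grad(theta, support_x, support_y)
    return theta - alpha * grad

theta = np.zeros((2, 1))
# one task's support set D^s with K=5 points (toy data)
xs = rng.normal(size=(5, 2))
ys = xs @ np.array([[1.0], [-2.0]])
theta_prime = inner_update(theta, xs, ys)
loss_before, _ = loss_and_grad(theta, xs, ys)
loss_after, _ = loss_and_grad(theta_prime, xs, ys)
print(loss_after < loss_before)  # one adaptation step reduces the support loss
```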
Outer loop
- For each task τm, compute the loss Lm(fθm′) on Dmq.
- Take a gradient-descent step on the sum of all these losses:
- θ ← θ − β∇θ Σ_{Ti∼p(T)} LTi(fθi′)
- Note that, since θi′ = θ − α∇θLTi(fθ), the update of θ involves second-order differentiation (a gradient through a gradient).
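The full outer loop, including the second-order term, can be sketched for the same toy linear/MSE setting. For a linear model the Hessian is constant, so the Jacobian dθ′/dθ = (I − αH) has a closed form; in practice this term comes from automatic differentiation, and all sizes and step values here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(1)
alpha, beta = 0.05, 0.02  # inner/outer step sizes (hypothetical values)

def grad_mse(theta, x, y):
    """Gradient of the MSE loss for a linear model f_theta(x) = x @ theta."""
    return 2.0 * x.T @ (x @ theta - y) / len(x)

def hessian_mse(x):
    """For a linear model with MSE the Hessian is constant: 2 X^T X / n."""
    return 2.0 * x.T @ x / len(x)

def meta_gradient(theta, support, query):
    """Exact second-order MAML gradient for one task:
    with theta' = theta - alpha * grad L_s(theta),
    d L_q(f_theta') / d theta = (I - alpha * H_s) @ grad L_q(theta')."""
    (xs, ys), (xq, yq) = support, query
    theta_p = theta - alpha * grad_mse(theta, xs, ys)
    jac = np.eye(len(theta)) - alpha * hessian_mse(xs)  # d theta' / d theta
    return jac @ grad_mse(theta_p, xq, yq)

def outer_step(theta, tasks):
    """theta <- theta - beta * sum_i grad L_Ti(f_{theta_i'})."""
    return theta - beta * sum(meta_gradient(theta, s, q) for s, q in tasks)

# Toy task family (hypothetical): linear tasks with weights near a shared mean.
def make_task(w, K=10):
    xs, xq = rng.normal(size=(K, 2)), rng.normal(size=(K, 2))
    return (xs, xs @ w), (xq, xq @ w)

w0 = np.array([[1.0], [-1.0]])
tasks = [make_task(w0 + 0.1 * rng.normal(size=(2, 1))) for _ in range(8)]

def avg_post_adaptation_loss(theta):
    """Mean query loss after one inner-loop step, i.e. the meta-objective."""
    losses = []
    for (xs, ys), (xq, yq) in tasks:
        tp = theta - alpha * grad_mse(theta, xs, ys)
        losses.append(float(np.mean((xq @ tp - yq) ** 2)))
    return sum(losses) / len(losses)

theta = np.zeros((2, 1))
loss_before = avg_post_adaptation_loss(theta)
for _ in range(200):
    theta = outer_step(theta, tasks)
loss_after = avg_post_adaptation_loss(theta)
print(loss_after < loss_before)  # meta-training reduces the meta-objective
```

The closed-form (I − αH) Jacobian is specific to this linear toy; frameworks like PyTorch obtain it by backpropagating through the inner update.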
Consequently, the meta-test procedure is as follows:
- Sample a batch of held-out tasks Ti∗ ∼ p(T) \ Ti, disjoint from the meta-training tasks.
- Run the full algorithm (inner-loop adaptation, then evaluation) on them.
- Note that learning itself is part of what the model is evaluated on: the meta-trained model is tested on how well it adapts from K samples.
3. Species of MAML
3-1. Supervised Regression and Classification
MSE for regression
CE-loss for classification
MAML for Few-Shot Supervised Learning
- Note that, K-shot classification tasks use K input/output pairs from each class, thus NK data points for N-way classification.
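The N-way K-shot counting above can be sketched by assembling a support set from a toy dataset (the dataset shapes and helper names here are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_support(dataset_by_class, classes, K):
    """Pick the N given classes and K input/output pairs per class -> N*K pairs total."""
    xs, ys = [], []
    for label, cls in enumerate(classes):
        idx = rng.choice(len(dataset_by_class[cls]), size=K, replace=False)
        for i in idx:
            xs.append(dataset_by_class[cls][i])
            ys.append(label)  # labels are re-indexed 0..N-1 per task
    return np.array(xs), np.array(ys)

# toy dataset: 5 classes, 20 examples of 8 features each (hypothetical)
data = {c: rng.normal(size=(20, 8)) for c in range(5)}
xs, ys = sample_support(data, classes=[0, 2, 4], K=5)  # 3-way 5-shot
print(xs.shape)  # (15, 8): N*K = 3*5 data points
```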
3-2. Reinforcement Learning
Definitions
- T = { L(x1,a1,...,xH,aH),q(x1),q(xt+1∣xt,at),H }
T : a task (each learning problem)
- L(xt,at) : Loss function with observation xt, output at
- q(xt) : A distribution of initial observations
- q(xt+1∣xt,at) : A transition distribution
- H : the episode length. At each time step t, the model chooses an output at, generating a trajectory of length H that serves as the query samples.
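The RL task tuple above can be sketched as a small data structure plus a rollout that generates one length-H trajectory; the field names and the toy 1-D task are illustrative, not from the paper:

```python
from dataclasses import dataclass
from typing import Callable
import numpy as np

rng = np.random.default_rng(0)

@dataclass
class RLTask:
    """T = (L, q(x1), q(x_{t+1}|x_t, a_t), H); field names are illustrative."""
    loss: Callable        # L over the trajectory (e.g. negative cumulative reward)
    init_dist: Callable   # q(x1): sample an initial observation
    transition: Callable  # q(x_{t+1} | x_t, a_t): sample the next observation
    horizon: int          # H: episode length

def rollout(task, policy):
    """Generate one length-H trajectory by repeatedly choosing an output a_t."""
    x = task.init_dist()
    traj = []
    for _ in range(task.horizon):
        a = policy(x)
        traj.append((x, a))
        x = task.transition(x, a)
    return traj, task.loss(traj)

# toy 1-D task (hypothetical): noisy drift, loss = sum of squared states
task = RLTask(
    loss=lambda traj: sum(x ** 2 for x, _ in traj),
    init_dist=lambda: float(rng.normal()),
    transition=lambda x, a: x + a + 0.01 * float(rng.normal()),
    horizon=5,
)
traj, L = rollout(task, policy=lambda x: -0.5 * x)
print(len(traj))  # H = 5
```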
Algorithm
Terminology
- task : (classification, for example) each piece of work performs classification over a specific subset of classes, which may not cover the whole class range. Thus, there can be more than one task under one dataset.