The focus of this article is on understanding the Gradient Episodic Memory (GEM) model proposed by Lopez-Paz and Ranzato. We examine how the model addresses catastrophic forgetting, how to create a dataset that is closer to a real-life scenario, and how to measure the accuracy of the model. We also look closely at the functions the authors use, to understand the implementation details of the model.
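Lopez-Paz and Ranzato proposed a model that circumvents the limitations of Empirical Risk Minimization (ERM), the principle behind most supervised learning methods. In ERM, given a training set D_tr of pairs (xᵢ, yᵢ), the predictor f is found by minimizing the empirical risk

$$\frac{1}{|D_{tr}|} \sum_{(x_i, y_i) \in D_{tr}} \ell(f(x_i), y_i),$$

where ℓ is a loss function penalizing the prediction errors.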
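As a minimal illustration of ERM, the sketch below assumes a linear predictor f(x) = w·x + b and the squared loss (both choices are made for the example only, they are not part of GEM). For this predictor class and loss, minimizing the empirical risk is an ordinary least-squares problem, solved here with NumPy. The function names are hypothetical.

```python
import numpy as np

def empirical_risk(f, loss, data):
    # Average loss of predictor f over the training pairs (x_i, y_i)
    return sum(loss(f(x), y) for x, y in data) / len(data)

def erm_linear_least_squares(xs, ys):
    # Linear predictor + squared loss: ERM reduces to ordinary least squares
    X = np.column_stack([xs, np.ones(len(xs))])  # append a bias column
    (w, b), *_ = np.linalg.lstsq(X, ys, rcond=None)
    return lambda x: w * x + b

xs = np.array([0.0, 1.0, 2.0, 3.0])
ys = 2.0 * xs + 1.0  # noiseless toy data
f = erm_linear_least_squares(xs, ys)
squared = lambda pred, y: (pred - y) ** 2
risk = empirical_risk(f, squared, list(zip(xs, ys)))  # close to 0 for this noiseless data
```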
ERM relies on the assumption that each training example (xᵢ,yᵢ) is an independent and identically distributed (iid) sample from a fixed probability distribution P that describes a single learning task. However, the iid assumption does not hold for the human learning process: humans observe data as an ordered sequence, seldom observe the same example twice, memorize only a few pieces of data, and the sequence of examples concerns different learning tasks.
This makes it difficult to apply the ERM principle directly, and McCloskey and Cohen showed that its straightforward application leads to “catastrophic forgetting”: after learning new tasks, the learner forgets how to solve the previous tasks.
This article uses a rotated version of the MNIST dataset. However, to demonstrate that the proposed GEM model works well in more human-like learning scenarios, the rotated dataset is turned into a continuum of examples.
This continuum consists of triplets (xᵢ, tᵢ, yᵢ), each formed by a feature vector xᵢ ∈ X (an image), a target vector yᵢ ∈ Y (its label), and a task descriptor tᵢ ∈ T that identifies the task associated with the pair (xᵢ, yᵢ).
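To make the continuum concrete, the sketch below builds one from a stack of images, assuming each task t is defined by a fixed rotation angle. The angles, the nearest-neighbour helper `rotate_image`, and `make_rotated_continuum` are illustrative assumptions, not the authors' code. Examples are shuffled within a task, but the tasks arrive one block after another, so the stream is not iid over (x, t, y).

```python
import numpy as np

def rotate_image(img, degrees):
    # Nearest-neighbour rotation of a 2-D image about its centre
    # (a hypothetical helper; an image library's rotate routine could be used instead)
    h, w = img.shape
    theta = np.deg2rad(degrees)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]
    # Inverse mapping: find the source pixel of every output pixel
    src_x = np.rint(np.cos(theta) * (xs - cx) + np.sin(theta) * (ys - cy) + cx).astype(int)
    src_y = np.rint(-np.sin(theta) * (xs - cx) + np.cos(theta) * (ys - cy) + cy).astype(int)
    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[valid] = img[src_y[valid], src_x[valid]]  # pixels mapped from outside stay 0
    return out

def make_rotated_continuum(images, labels, angles, per_task, rng):
    # Returns a list of triplets (x_i, t_i, y_i); task t uses rotation angles[t]
    stream = []
    for t, angle in enumerate(angles):
        idx = rng.choice(len(images), size=per_task, replace=False)  # random order within the task
        for i in idx:
            x = rotate_image(images[i], angle).ravel()  # flattened feature vector
            stream.append((x, t, labels[i]))
    return stream  # tasks appear in contiguous blocks, not interleaved
```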
More importantly, the examples are not drawn iid from a fixed probability distribution over the triplets (x, t, y); that is, the learner may observe a sequence of examples from one task before switching to another task.
The test set is also drawn from a continuum: a test pair (x,y) ∼ Pₜ can belong to a task observed in the past, to the current task being learned, or to a task that has not been observed yet. This test set is used to determine the average accuracy and knowledge transfer. In this project, the task descriptors are integers tᵢ = i ∈ ℤ. More generally, the task descriptor could be a paragraph of natural language explaining how to solve the i-th task; such rich task descriptors could enable zero-shot learning based on the inferred relations between tasks.
To simulate a “more human-like” setting, the model is trained under the following conditions: (i) the number of training examples per task is small, (ii) the number of tasks is large, (iii) each training example is observed only once, and (iv) both knowledge transfer and forgetting are measured. Each training example is a triplet (xᵢ,tᵢ,yᵢ), and the tasks are streamed in sequence but in no particular order.
The Network Architecture
The neural network has two hidden layers of ReLU units, and its weights are initialized with the “He initialization” scheme (a variant of “Xavier initialization” adapted for ReLU). The network is trained with Stochastic Gradient Descent (SGD) and a cross-entropy loss.
To narrow the gap between ERM and the more human-like learning process, the GEM model includes an episodic memory Mₜ that stores a subset of the observed examples from task t.
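The memory size is restricted to a total of M locations. If the total number of tasks T is known, each task receives m = M/T memories; otherwise, m can be reduced as new tasks are observed. The examples stored in the memory Mₖ of task k are used to find predictors f_θ, with parameters θ ∈ ℝᵖ, by minimizing the following loss function:

$$\ell(f_\theta, \mathcal{M}_k) = \frac{1}{|\mathcal{M}_k|} \sum_{(x_i, k, y_i) \in \mathcal{M}_k} \ell(f_\theta(x_i, k), y_i).$$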
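The sketch below assumes the total number of tasks T is known, so each task gets m = M // T slots. It keeps a separate buffer per task and evaluates the memory loss ℓ(f_θ, Mₖ) as the average loss over the stored examples of task k. Keeping the most recent m examples of each task is an illustrative policy, and `EpisodicMemory` and `memory_loss` are hypothetical names.

```python
from collections import deque

class EpisodicMemory:
    """Budget of M slots split evenly across T tasks (assumes T is known),
    so each task keeps m = M // T examples. Hypothetical helper class."""

    def __init__(self, total_slots, num_tasks):
        self.m = total_slots // num_tasks
        if self.m == 0:
            raise ValueError("need at least one memory slot per task")
        self.buffers = {}

    def add(self, x, t, y):
        # Keep the m most recent examples of task t (an illustrative policy);
        # deque(maxlen=m) drops the oldest entry once it is full
        self.buffers.setdefault(t, deque(maxlen=self.m)).append((x, t, y))

    def examples(self, k):
        return list(self.buffers.get(k, ()))

def memory_loss(predict, loss, memory, k):
    # l(f_theta, M_k): average loss of the predictor over task k's stored examples
    stored = memory.examples(k)
    return sum(loss(predict(x, t), y) for x, t, y in stored) / len(stored)

mem = EpisodicMemory(total_slots=10, num_tasks=5)  # m = 2 slots per task
for i in range(5):
    mem.add(float(i), 0, 0.0)  # only the last two examples of task 0 survive
```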
However, minimizing this loss directly on the examples stored in episodic memory tends to overfit those examples, and keeping the predictions on past tasks invariant (for example, through distillation) makes positive backward transfer impossible.
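GEM overcomes this problem by finding the predictor f_θ for the current task t under a constraint: the parameter update must not increase the loss on the examples stored in the episodic memory of any previous task k < t, compared with the predictor f_θ^{t−1} obtained at the end of learning task t−1. For the current training example (x, t, y), the problem can be stated as

$$\begin{aligned} \min_{\theta} \quad & \ell(f_\theta(x, t), y) \\ \text{subject to} \quad & \ell(f_\theta, \mathcal{M}_k) \le \ell(f_\theta^{t-1}, \mathcal{M}_k) \quad \text{for all } k < t. \end{aligned}$$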
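At first glance, this seems to require storing the old predictor f_θ^{t−1}, at the cost of additional memory. This is unnecessary: it is enough to guarantee that the loss on the previous tasks does not increase after each parameter update g. Assuming that the loss is locally linear and that the memories are representative of the previous tasks, such an increase can be detected by comparing the proposed update with the loss gradient vector gₖ of each previous task, computed on its episodic memory:

$$\langle g, g_k \rangle := \left\langle \frac{\partial \ell(f_\theta(x, t), y)}{\partial \theta}, \frac{\partial \ell(f_\theta, \mathcal{M}_k)}{\partial \theta} \right\rangle \ge 0 \quad \text{for all } k < t.$$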
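GEM checks whether a proposed update g conflicts with a previous task by the sign of the dot product ⟨g, gₖ⟩, where gₖ is the loss gradient on task k's episodic memory. A minimal sketch, assuming the previous-task gradients are stacked as the rows of a NumPy array (the function name is hypothetical):

```python
import numpy as np

def violated_tasks(g, past_grads):
    """Indices k whose memory gradient g_k has a negative dot product with g.
    For such k, the SGD step theta - lr * g would, to first order,
    increase the loss on task k's episodic memory."""
    dots = past_grads @ g  # one dot product <g, g_k> per previous task (row k)
    return np.flatnonzero(dots < 0)
```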
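If one or more of these inequality constraints are violated, the loss on at least one of the previous tasks would increase after the proposed update. In that case, the proposed gradient g is projected onto the closest gradient g̃ (in squared ℓ₂ norm) that satisfies all the constraints. The problem can be written as

$$\begin{aligned} \min_{\tilde g} \quad & \tfrac{1}{2} \lVert g - \tilde g \rVert_2^2 \\ \text{subject to} \quad & \langle \tilde g, g_k \rangle \ge 0 \quad \text{for all } k < t. \end{aligned}$$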
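With a single previous task the projection has a closed form: when ⟨g, g₁⟩ < 0, the closest vector satisfying the constraint is the orthogonal projection of g onto the hyperplane ⟨g̃, g₁⟩ = 0. The sketch below (function name hypothetical) covers only this special case; with several previous tasks the constraints interact, which is why GEM solves a quadratic program instead.

```python
import numpy as np

def project_single_constraint(g, g_k):
    # Closest vector to g (in l2 norm) satisfying <g_tilde, g_k> >= 0
    dot = g @ g_k
    if dot >= 0:
        return g.copy()  # constraint already satisfied: g is left unchanged
    # Remove the component of g along g_k that violates the constraint
    return g - (dot / (g_k @ g_k)) * g_k
```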
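The optimization problem above has a quadratic objective and linear constraints, so it is a quadratic program (QP) and can be solved with a standard QP solver. Writing z for the projected gradient, letting G be the (t−1) × p matrix whose k-th row is gₖ, and expanding the squared ℓ₂ norm gives

$$\begin{aligned} \min_{z} \quad & \tfrac{1}{2} z^\top z - g^\top z + \tfrac{1}{2} g^\top g \\ \text{subject to} \quad & G z \ge 0, \end{aligned}$$

where the constant term ½gᵀg can be discarded because it does not change the minimizer.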
This primal problem has p variables, one per parameter of the neural network, and p can be in the millions. Its dual, however, has only t−1 variables, one per previously observed task, which is much smaller than p. By the QP duality result of Dorn (1960), the dual is

$$\begin{aligned} \min_{v} \quad & \tfrac{1}{2} v^\top G G^\top v + g^\top G^\top v \\ \text{subject to} \quad & v \ge 0, \end{aligned}$$

where the rows of G are the previous-task gradients gₖ.
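The step from the primal to the dual can be filled in with the standard Lagrangian argument below (a textbook derivation, not reproduced from the article), using the primal objective ½zᵀz − gᵀz with the constant dropped, multipliers v ≥ 0 for the constraints Gz ≥ 0, and G the matrix whose rows are the previous-task gradients gₖ. Since the primal is convex with linear constraints and z = 0 is feasible, strong duality holds.

```latex
% Lagrangian, with multipliers v >= 0 for the constraints G z >= 0
L(z, v) = \tfrac{1}{2} z^\top z - g^\top z - v^\top G z

% Stationarity in z
\nabla_z L = z - g - G^\top v = 0 \quad\Longrightarrow\quad z = g + G^\top v

% Substituting z back gives the dual function
D(v) = \min_z L(z, v) = -\tfrac{1}{2} v^\top G G^\top v - g^\top G^\top v - \tfrac{1}{2} g^\top g

% Maximizing D over v >= 0 (the constant -g^T g / 2 is irrelevant) is equivalent to
\min_{v \ge 0} \; \tfrac{1}{2} v^\top G G^\top v + g^\top G^\top v,
\qquad z^\star = g + G^\top v^\star
```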
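Given the dual solution v*, the projected gradient update g̃ is recovered as

$$\tilde g = G^\top v^\star + g,$$

and the parameters are then updated with g̃ in place of g.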
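Putting the pieces together, the sketch below projects a proposed gradient g given the matrix G whose rows are the previous-task gradients gₖ. It solves the dual QP by projected gradient descent with step 1/λ_max(GGᵀ), which is an illustrative choice (a dedicated QP solver would normally be used), and recovers g̃ = Gᵀv* + g. The function name and the iteration and tolerance defaults are assumptions.

```python
import numpy as np

def gem_project(g, G, iters=5000, tol=1e-10):
    """Project gradient g so that G @ g_tilde >= 0, where row k of G is the
    gradient g_k computed on the episodic memory of previous task k.
    The dual QP  min_v 0.5 v^T G G^T v + g^T G^T v  s.t.  v >= 0
    is solved by projected gradient descent (illustrative solver)."""
    if np.all(G @ g >= 0):
        return g.copy()  # no constraint violated: keep the proposed update
    H = G @ G.T          # (t-1) x (t-1) dual Hessian
    c = G @ g            # linear term of the dual objective
    step = 1.0 / np.linalg.eigvalsh(H).max()  # step size 1/L for the dual gradient
    v = np.zeros(len(c))
    for _ in range(iters):
        v_new = np.maximum(0.0, v - step * (H @ v + c))  # gradient step, then clip to v >= 0
        if np.max(np.abs(v_new - v)) < tol:
            v = v_new
            break
        v = v_new
    return G.T @ v + g   # recover the primal solution g_tilde = G^T v* + g
```

For a single previous task this reduces to the closed-form half-space projection; with several tasks, the dual has only t−1 variables regardless of how many parameters the network has.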
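The accuracy of the model and its backward transfer are measured each time it finishes training on a new task. Backward transfer (BWT) is the influence that learning a new task t has on the performance on a previous task k < t: a positive backward transfer increases the performance on some preceding task k, and a negative backward transfer decreases it. Similarly, forward transfer (FWT) is the influence that learning task t has on the performance on a future task k > t: a positive forward transfer increases the performance of the model on that future task, and a negative forward transfer reduces it. Let R be the T × T matrix whose entry Rᵢⱼ is the test accuracy on task j after the model has observed the last example of task i, and let b̄ be the vector of test accuracies on each task at random initialization. The metrics are then

$$\text{ACC} = \frac{1}{T} \sum_{i=1}^{T} R_{T,i}, \qquad \text{BWT} = \frac{1}{T-1} \sum_{i=1}^{T-1} \left(R_{T,i} - R_{i,i}\right), \qquad \text{FWT} = \frac{1}{T-1} \sum_{i=2}^{T} \left(R_{i-1,i} - \bar b_i\right).$$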
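The sketch below computes average accuracy, backward transfer, and forward transfer from an accuracy matrix R and the vector b of accuracies at random initialization. Indices are 0-based here, so R[i, j] is the test accuracy on task j after training on task i. The function name is hypothetical and the numbers in the usage example are made up for illustration.

```python
import numpy as np

def continual_metrics(R, b):
    # R[i, j]: accuracy on task j after training on task i; b[j]: accuracy at init
    R = np.asarray(R, dtype=float)
    b = np.asarray(b, dtype=float)
    T = R.shape[0]
    acc = R[-1].mean()                              # final accuracy averaged over all tasks
    bwt = (R[-1, :-1] - np.diag(R)[:-1]).mean()     # final vs just-learned accuracy
    fwt = np.mean([R[i - 1, i] - b[i] for i in range(1, T)])  # before-training vs init
    return acc, bwt, fwt

R = [[0.90, 0.20, 0.10],
     [0.80, 0.90, 0.30],
     [0.70, 0.85, 0.95]]
b = [0.10, 0.10, 0.10]
acc, bwt, fwt = continual_metrics(R, b)  # acc ≈ 0.833, bwt ≈ -0.125, fwt ≈ 0.15
```

A negative BWT, as in this toy matrix, indicates forgetting: accuracy on earlier tasks dropped after later tasks were learned.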
The original paper provides the implementation details needed to train the model on the datasets it considers. We do not cover those details here, as they can vary from application to application.