Prioritized Training on Points that are Learnable, Worth Learning, and Not Yet Learnt

Part of an ongoing series highlighting insights from papers that have contributed to the development of best practices for production ML

Josh Tobin
September 01, 2022

Production ML Papers to Know

This is a continuation of Production ML Papers to Know, a series from Gantry highlighting papers we think have been important to the evolving practice of production ML.

This blog post started life as part of our weekly newsletter in the first week of September, so you may have caught the material already.

Fast Model Editing at Scale

One frustrating truth about production machine learning is that models, especially large ones, tend to fail in spectacular and unexpected ways. Suppose your model makes a mistake, like answering "who is the president of the US" incorrectly, or suggesting something offensive.

[Image: an unexpected model output]

If this were traditional software, we'd find the bug, write a patch, and deploy it. Relatively straightforward!

However, in ML, fixing one-off issues is not so easy. You could try fine-tuning the model on that data point, but if you just have one data point you'll probably over-fit your model on it. You could also try one-shot or meta-learning, but these techniques usually require you to rethink your whole training process. 

In Fast Model Editing at Scale, the authors instead propose to build a "model editor". 

Model editors

Given a data point you want to change the prediction for, a model editor tries to

  • Fix the prediction for that input
  • Make similar changes for similar inputs
  • Without changing behavior on unrelated inputs

The key idea is to train a second neural network to edit the gradient of the target model you're working on. So you fine-tune your target model like normal, except instead of using the raw gradient you get from backpropagation, you first pass it through an edit model and use that output instead.
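To make that concrete, here is a minimal sketch of a single edit step. Everything in it is an assumption made for illustration: the target model is a tiny linear predictor rather than a real network, and `edit_model` is a hypothetical stand-in that just rescales each gradient component, where the real thing would be a trained neural network.

```python
# Toy target model: a linear predictor y = w . x (a stand-in for a real network).
def predict(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))


def raw_gradient(w, x, y_target):
    """Gradient of the squared error 0.5 * (w.x - y)^2 with respect to w."""
    err = predict(w, x) - y_target
    return [err * xi for xi in x]


def edit_model(grad, scale):
    """Hypothetical editor: rescales each gradient component.

    In the approach described above this would be a trained neural network;
    a fixed per-parameter scale keeps the sketch self-contained.
    """
    return [s * g for s, g in zip(scale, grad)]


def apply_edit(w, x, y_target, scale, lr=1.0):
    """One fine-tuning step that uses the edited gradient, not the raw one."""
    g = raw_gradient(w, x, y_target)
    g_edited = edit_model(g, scale)
    return [wi - lr * gi for wi, gi in zip(w, g_edited)]


w = [0.0, 0.0]
x_edit, y_edit = [1.0, 0.0], 1.0   # the input whose prediction we want to fix
x_unrelated = [0.0, 1.0]           # an input the edit should leave alone

w_new = apply_edit(w, x_edit, y_edit, scale=[1.0, 0.0])
print(predict(w_new, x_edit), predict(w_new, x_unrelated))
```

The only difference from ordinary fine-tuning is the extra call to `edit_model` between computing the gradient and applying it.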

So, how can you build an edit model?

We'll need a dataset to train on and a loss function to optimize. The dataset consists of 3 parts.

  • Edit pairs: the data we want to edit
  • Equivalence pairs: similar data that should also be changed by the edit
  • Locality pairs: unrelated data that should be unchanged

To compute the edit model loss, we first compute the gradient of the target model on the edit pair. That gradient is the input to the edit model we're trying to train.

Then the edit model loss balances changing the equivalence pair while not changing the locality pair. 
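One way to picture that balance is as a weighted sum of two terms. The sketch below is our reading of the idea rather than the authors' code: the edit term is the negative log-likelihood of the desired label on the equivalence pair after the edit, and the locality term is the KL divergence between the model's pre-edit and post-edit predictions on the locality pair. The function names, the toy logits, and the `c_edit` weight are all assumptions made for illustration.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def nll(probs, label):
    """Negative log-likelihood of the desired label."""
    return -math.log(probs[label])


def kl(p, q):
    """KL(p || q) for two discrete distributions with full support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def editor_loss(pre_loc_logits, post_loc_logits, post_equiv_logits,
                equiv_label, c_edit=0.1):
    """Hypothetical editor loss: c_edit * edit term + locality term."""
    # Did the edited model get the equivalence input right?
    edit_term = nll(softmax(post_equiv_logits), equiv_label)
    # Did the edited model's output on the locality input stay the same?
    locality_term = kl(softmax(pre_loc_logits), softmax(post_loc_logits))
    return c_edit * edit_term + locality_term


# A good edit: equivalence input now predicts label 1, locality output unchanged.
good = editor_loss([2.0, 0.0], [2.0, 0.0], [0.0, 5.0], equiv_label=1)
# A bad edit: equivalence input still wrong, locality output flipped.
bad = editor_loss([2.0, 0.0], [0.0, 2.0], [5.0, 0.0], equiv_label=1)
print(good < bad)
```

In a real setup the "post" logits come from the target model after applying the edited gradient, so minimizing this loss trains the edit model through that update.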


There are a few challenges you'll run into if you try to implement this.

  • The gradient we're updating can have a lot of parameters (as many as your target network!), which makes it impractical to use as an input to another neural net
  • This model can be tricky to train in practice

The authors share details about how to address these challenges in the paper. 
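For the first challenge, the trick the paper leans on (as we understand it) is that the gradient of a fully connected layer's weights for a single example is a rank-1 outer product, so the edit model can work with the two small factors instead of the full matrix. The sketch below only illustrates that size argument; the variable names and layer sizes are assumptions, and it does not implement the editor itself.

```python
def linear_layer_gradient(u, delta):
    """Full weight gradient for one example: the outer product delta * u^T.

    u     -- the input to the layer (length d_in)
    delta -- the gradient of the loss w.r.t. the layer output (length d_out)
    """
    return [[d * ui for ui in u] for d in delta]


def input_sizes(d_in, d_out):
    """Numbers an editor would need to consume: full matrix vs. the two factors."""
    return d_in * d_out, d_in + d_out


u = [1.0, 2.0, 3.0]      # d_in = 3
delta = [0.5, -1.0]      # d_out = 2

full_grad = linear_layer_gradient(u, delta)
print(full_grad)

# For a layer with hypothetical sizes 4096 x 4096, the gap is large.
full_size, factored_size = input_sizes(4096, 4096)
print(full_size, factored_size)
```

Feeding the editor `u` and `delta` rather than the full gradient is what keeps its input size linear in the layer width instead of quadratic.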


The results are promising: the technique corrects problems better than fine-tuning while making minimal impact on non-edited data.

It's worth checking out the qualitative results as well.

[Image: table of qualitative results from the paper]

The Upshot

This is an exciting approach to rapidly fixing model mistakes! So should you use it?

Probably not yet.

  • Training these models might be hard. When papers require a large number of tricks to make them work, it's sometimes a sign that the training will be hard to reproduce
  • There's still a small risk of damaging the model when you apply an edit, so you'll need good evaluation to make sure the edited model doesn't have new problems
  • Constructing edit datasets is more of an art today. You need to manually find "similar" edits that should have the same effect as the example you're trying to change, which is problem-dependent

Check out the paper here: