Google DeepMind recently published a new algorithm for curating AI training datasets: multimodal contrastive learning with joint example selection (JEST), which uses a pre-trained model to score the learnability of batches of data. Google's experiments show that image-text models trained with JEST-curated data require 10x less computation than baseline methods.
JEST addresses the problem of curating training datasets: filtering a dataset down to the specific examples that will be most effective for training a model. Because manual curation is time-consuming, JEST automates the process, using a pre-trained reference model to select the best batches of samples based on their learnability score, which combines the losses from both the reference model and the learner model being trained. The goal is to find batches that have a high loss for the learner but a low one for the reference, meaning the data is both "unlearned and learnable." According to Google,
[W]e find that central to the performance of our framework is the ability to steer the curation process towards the distribution of smaller, well-curated datasets...Crucially, we find this process [enables] strong data quality bootstrapping: a reference model trained on a small curated dataset can effectively guide the curation of a much larger dataset, allowing the training of a model which strongly surpasses the quality of the reference model on many downstream tasks.
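The learnability score itself is simple to express. The following sketch is a minimal illustration, not DeepMind's implementation: it assumes both models expose a per-batch contrastive loss, and the function names are hypothetical. It scores a candidate batch as the gap between the learner's loss and the reference model's loss:

```python
import torch

def learnability_score(batch, learner_loss_fn, reference_loss_fn):
    """Score a candidate batch of image-text pairs.

    A high learner loss means the batch is not yet learned; a low
    reference loss means it is learnable. The difference is the
    learnability score that batch selection tries to maximize.
    """
    with torch.no_grad():  # scoring only; no gradient update here
        loss_learner = learner_loss_fn(batch)      # high => unlearned
        loss_reference = reference_loss_fn(batch)  # low  => learnable
    return loss_learner - loss_reference
```

Batches with the highest scores are the ones actually used for the next training step.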
JEST is applied during the training process. Given a large super-batch of training data, JEST iteratively selects smaller chunks, or sub-batches, by calculating their joint learnability conditioned on the sub-batches previously sampled. The research team found that this joint selection improves the quality of the batches, an effect similar to that of hard negatives.
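To make the joint, iterative selection concrete, the sketch below greedily grows a sub-batch in chunks, rescoring the remaining candidates together with what has already been picked. This is a simplification of the paper's sampling-based selection, and `score_fn` is a hypothetical helper returning a per-candidate learnability score computed jointly with the already-selected examples (the contrastive loss depends on batch composition, which is why selection is joint):

```python
import torch

def jest_select(super_batch, score_fn, n_chunks=4, chunk_size=64):
    """Build a training sub-batch from a large super-batch, chunk by chunk."""
    selected = []                                 # indices chosen so far
    remaining = list(range(len(super_batch)))     # candidate indices
    for _ in range(n_chunks):
        scores = score_fn([super_batch[i] for i in remaining],
                          [super_batch[i] for i in selected])
        # Sample a chunk in proportion to learnability rather than taking
        # a hard top-k, mirroring the paper's sampling-based approach.
        probs = torch.softmax(torch.as_tensor(scores, dtype=torch.float), dim=0)
        k = min(chunk_size, len(remaining))
        chunk = torch.multinomial(probs, k, replacement=False)
        for idx in sorted(chunk.tolist(), reverse=True):
            selected.append(remaining.pop(idx))   # move index to selected
    return [super_batch[i] for i in selected]
```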
Because the learnability score is computed online during training, it imposes additional compute cost. To offset this, JEST approximates the models used for scoring; for example, the vision component of the reference model can drop layers or image patches. The researchers also improved efficiency by training the learner on images at multiple resolutions.
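One way to picture that approximation is to score on cheaper inputs than the ones used for the actual gradient step. The sketch below is again an illustration under assumptions, not DeepMind's code: it randomly drops image patch tokens before they reach the reference model's vision encoder, trading a noisier learnability estimate for a cheaper forward pass:

```python
import torch

def drop_patches(patch_tokens, keep_ratio=0.5):
    """Keep a random subset of image patch tokens before encoding.

    `patch_tokens` has shape (num_patches, dim). Scoring on half the
    patches roughly halves the vision encoder's cost per example.
    """
    num_patches = patch_tokens.shape[0]
    num_keep = max(1, int(num_patches * keep_ratio))
    keep = torch.randperm(num_patches)[:num_keep]
    return patch_tokens[keep]
```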
The DeepMind team ran several experiments to evaluate JEST. They first trained an image-text reference model on a curated dataset based on the Web Language Image (WebLI) dataset. They then trained learner models with JEST and compared them to models trained with baseline uniform batch selection. Models trained using JEST matched the benchmark performance of baseline models while requiring 10x fewer training FLOPs.
In a discussion on Hacker News, several users praised DeepMind's work. One wrote:
So the paper itself is pretty significant, I think, from looking at it. The general methodology seems to be: train small model as a discriminatory scoring model on very high quality data...This turns out to be significant FLOPs and quality win, even counting for the initial model training and scoring part of it...As always, appreciate the publishing from DeepMind - this looks like great work.
Another user pointed out that JEST was similar to another method called Cappy, which also uses a "pretrained small scorer." Other related techniques include RHO-LOSS, which inspired JEST and is open-source. Google has not open-sourced JEST.