Google’s AI lab DeepMind recently released NFNets, a family of normalizer-free ResNet image classification models. Its NFNet-F1 variant matches the accuracy of the current state-of-the-art EfficientNet-B7 while training up to 8.7x faster.
According to the DeepMind researchers (see the plot below):
NFNet-F1 model achieves similar accuracy to EfficientNet-B7 while being 8.7× faster to train, and our largest model sets a new overall state of the art without extra data of 86.5% top-1 accuracy.
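Top-1 accuracy, the metric quoted above, counts a prediction as correct only when the model's highest-scoring class is the true label. Here is a small plain-Python sketch of the metric. The function name and toy data are illustrative assumptions, not taken from the paper:

```python
def top1_accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the true label."""
    correct = 0
    for scores, label in zip(logits, labels):
        # argmax over the class scores: the predicted class index
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Three toy examples over four classes.
logits = [[0.1, 2.0, 0.3, 0.0],
          [1.5, 0.2, 0.1, 0.4],
          [0.0, 0.1, 0.2, 3.0]]
labels = [1, 0, 2]
print(top1_accuracy(logits, labels))  # 2 of 3 correct, so about 0.667
```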
For large-scale image recognition tasks, neural networks usually rely on a technique called batch normalization to make model training more efficient. It also helps networks generalize better; that is, it has a regularizing effect.
However, batch normalization has disadvantages. It behaves differently at training and inference time, and it adds computational and memory overhead, because the batch statistics and intermediate values of every normalized layer must be stored for the backward pass (backpropagation, the process by which neural networks learn).
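To make the training/inference discrepancy concrete, here is a minimal batch-normalization sketch in plain Python. It is a simplification under stated assumptions: no learned scale and shift, and a biased variance for the running estimate, so details differ from real frameworks such as PyTorch. In training mode it normalizes with the current batch's statistics; in inference mode it uses running averages, so the same input produces different outputs:

```python
import math


class BatchNorm1D:
    """Minimal batch norm over a batch of feature vectors (no learned scale/shift)."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        # Running statistics, used only at inference time.
        self.running_mean = [0.0] * num_features
        self.running_var = [1.0] * num_features

    def __call__(self, batch, training):
        n = len(batch)
        cols = list(zip(*batch))  # one tuple of values per feature
        if training:
            # Training: normalize with the statistics of the current mini-batch...
            mean = [sum(c) / n for c in cols]
            var = [sum((v - m) ** 2 for v in c) / n for c, m in zip(cols, mean)]
            # ...and update the running averages for later inference.
            self.running_mean = [(1 - self.momentum) * r + self.momentum * m
                                 for r, m in zip(self.running_mean, mean)]
            self.running_var = [(1 - self.momentum) * r + self.momentum * v
                                for r, v in zip(self.running_var, var)]
        else:
            # Inference: use the stored running statistics instead.
            mean, var = self.running_mean, self.running_var
        return [[(v - m) / math.sqrt(s + self.eps) for v, m, s in zip(row, mean, var)]
                for row in batch]


bn = BatchNorm1D(num_features=1)
batch = [[1.0], [3.0]]
train_out = bn(batch, training=True)   # batch mean 2, var 1: about [[-1.0], [1.0]]
eval_out = bn(batch, training=False)   # running mean 0.2, var 1: about [[0.8], [2.8]]
print(train_out, eval_out)
```

The same two inputs come out differently in the two modes, which is exactly the behaviour gap that normalizer-free networks avoid.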
DeepMind introduced NFNets to remove normalization from the equation while improving training performance. In addition, the paper introduces a technique called adaptive gradient clipping (AGC), which makes it possible to train networks such as ResNets without normalization, efficiently and with larger batch sizes. This method reduced training time by 20-40% on the same computational resources (number of GPUs used) compared with an EfficientNet of the same accuracy.
Source: High-Performance Large-Scale Image Recognition Without Normalization
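AGC clips gradients unit-wise, based on the ratio of each unit's gradient norm to its weight norm: if ‖G_i‖ / max(‖W_i‖, ε) exceeds a threshold λ, the gradient G_i is rescaled so that the ratio equals λ. The sketch below applies that rule in plain Python to a weight matrix stored as a list of rows, treating each row as one unit. The function names and the defaults λ = 0.01 and ε = 1e-3 are assumptions for illustration; this is not DeepMind's code:

```python
import math


def unitwise_norm(rows):
    """L2 norm of each row; each row holds the weights (or gradients) of one unit."""
    return [math.sqrt(sum(v * v for v in row)) for row in rows]


def adaptive_grad_clip(weights, grads, clip=0.01, eps=1e-3):
    """Rescale each gradient row so that ||G_i|| / max(||W_i||, eps) <= clip."""
    clipped = []
    for w_norm, g_norm, g_row in zip(unitwise_norm(weights), unitwise_norm(grads), grads):
        max_norm = clip * max(w_norm, eps)
        if g_norm > max_norm:
            # Gradient is too large relative to the weights: shrink it to max_norm.
            clipped.append([g * (max_norm / g_norm) for g in g_row])
        else:
            # Within the threshold: keep the gradient unchanged.
            clipped.append(list(g_row))
    return clipped


weights = [[3.0, 4.0], [0.0, 0.0]]  # row norms: 5.0 and 0.0
grads = [[0.3, 0.4], [1e-6, 0.0]]   # row norms: 0.5 and 1e-6
print(adaptive_grad_clip(weights, grads))
# First row is scaled down to norm 0.05 (about [0.03, 0.04]); second row is unchanged.
```

The ε floor matters for units whose weights are at or near zero (such as zero-initialized parameters): without it, their allowed gradient norm would be zero and they could never start learning.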
DeepMind published the code on its GitHub, implemented in JAX, Google's numerical computing framework, together with DeepMind's Haiku neural-network library (imported as hk). To run a forward pass of NFNet, use the following code:
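import haiku as hk
import jax
from nfnets import nfnet  # NFNet module from DeepMind's released code

# Assumes `variant` (e.g. 'F0'), the pretrained `params`, an input image `x`
# and `imagenet_classlist` were set up earlier, as in DeepMind's demo notebook.
def forward(inputs, is_training):
  model = nfnet.NFNet(num_classes=1000, variant=variant)
  return model(inputs, is_training=is_training)['logits']

net = hk.without_apply_rng(hk.transform(forward))
fwd = jax.jit(lambda inputs: net.apply(params, inputs, is_training=False))

# Keep the call in a separate cell so the jitted fwd fn is not recompiled each time.
logits = fwd(x[None])  # Give x a new batch axis to make it batch-size-1
which_class = imagenet_classlist[int(logits.argmax())]
print(f'ImageNet class: {which_class}.')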
NFNets also has a community-made PyTorch implementation, which shows the community has been receptive to this release:
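import torch
from torch import nn
from torchvision.models import resnet18
from nfnets import WSConv2d  # community nfnets PyTorch package
from nfnets.agc import AGC  # Needs testing (per the package author)

# WSConv2d is a weight-standardized counterpart of nn.Conv2d.
conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

# Wrap a standard optimizer with adaptive gradient clipping (AGC). Needs testing.
optimizer = torch.optim.SGD(conv.parameters(), 1e-3)
optimizer_agc = AGC(conv.parameters(), optimizer)

# Ignore the final fc (classifier) layer of a model while applying AGC.
model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), 1e-3)
optimizer = AGC(model.parameters(), optimizer, model=model, ignore_agc=['fc'])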
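WSConv2d in the snippet above swaps nn.Conv2d for a weight-standardized convolution. NFNets build on Scaled Weight Standardization: each output unit's incoming weights are shifted to zero mean and scaled to unit variance, then divided by √N, where N is the fan-in. The plain-Python sketch below applies that reparameterization to a weight matrix stored as rows (one row per output unit). The function name, the way ε is added, and the default gain of 1.0 are illustrative assumptions; this is not the nfnets package's source:

```python
import math


def scaled_weight_standardize(rows, gain=1.0, eps=1e-4):
    """Standardize each row (one unit's incoming weights), then scale by gain / sqrt(fan_in)."""
    out = []
    for row in rows:
        fan_in = len(row)
        mean = sum(row) / fan_in
        var = sum((w - mean) ** 2 for w in row) / fan_in
        # Zero mean and unit variance per unit, then divide by sqrt(fan_in).
        scale = gain / math.sqrt((var + eps) * fan_in)
        out.append([(w - mean) * scale for w in row])
    return out


w_hat = scaled_weight_standardize([[1.0, 2.0, 3.0, 4.0]])
print(w_hat)  # the row now has zero mean and a squared norm close to 1
```

Because the standardization acts on the weights rather than on activations, it does not depend on the batch, so the layer behaves identically at training and inference time.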
Finally, a YouTube video about NFNets had over 30,000 views.