GoogleのDeepMind AI企業は最近NFNetsをリリースした。これは、ノーマライザーフリーResNet画像分類モデルであり、現在の最先端のEfficientNetよりも8.7倍速いトレーニングパフォーマンスを実現した。
GoogleのDeepMindの研究者によると次の通りである(以下のプロットをチェックしてください)。
NFNet-F1モデルは、EfficientNet-B7と同様の精度を実現している。一方で、トレーニングでは8.7倍高速である。我々の最大のモデルは、86.5%のトップ1精度の追加データなしで、新しい総合的な最先端技術に位置付けられる。
大規模な画像認識タスクの場合、通常、ニューラルネットワークはバッチ正規化と呼ばれる手法を使用し、それによってモデルトレーニングがより効率的になる。さらに、ニューラルネットワークがより一般化する助けとなる。つまり、正則化効果がある。
バッチ正規化には、トレーニング時間と推論時間の間の食い違いの動作や、計算オーバーヘッドといったいくつかの欠点があある。計算オーバーヘッドは、後のバックプロパゲーション(ニューラルネットワーク学習プロセス)に必要なネットワーク層ごとの特定のパラメーターの格納によるものである。
DeepMindは、方程式から正規化を削除し、トレーニングパフォーマンスを向上させるためにNFNetを導入した。これに加えて、適応勾配クリッピングと呼ばれる手法が導入されている。これによって、ResNetなどのニューラルネットワークモデルをより大きなバッチサイズで効率的にトレーニングできる。この方法では、同じ精度のEfficientNetと比較して、計算リソース(使用されるGPUの量)ごとにトレーニング時間が20~40%短縮された。
出典: 正規化なしの高性能大規模画像認識
コードはGoogleのDeepMind GitHubで公開され、JAXと呼ばれるこの新しいフレームワークに実装された。 NFNetでフォワードステップを実行するには、次のコードを実行するだけである。
def forward(inputs, is_training):
model = nfnet.NFNet(num_classes=1000, variant=variant)
return model(inputs, is_training=is_training)['logits']
net = hk.without_apply_rng(hk.transform(forward))
fwd = jax.jit(lambda inputs: net.apply(params, inputs, is_training=False))
# We split this into two cells so that we don't repeatedly jit the fwd fn.
logits = fwd(x[None]) # Give X a newaxis to make it batch-size-1
which_class = imagenet_classlist[int(logits.argmax())]
print(f'ImageNet class: {which_class}.')
NFNetsはPytorchにも実装されており、コミュニティがこのリリースを受け入れていることを示している。
import torch
from torch import nn, optim
from torchvision.models import resnet18
from nfnets import WSConv2d
from nfnets.agc import AGC # Needs testing
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
optim = optim.SGD(conv.parameters(), 1e-3)
optim_agc = AGC(conv.parameters(), optim) # Needs testing
# Ignore fc of a model while applying AGC.
model = resnet18()
optim = torch.optim.SGD(model.parameters(), 1e-3)
optim = AGC(model.parameters(), optim, model=model, ignore_agc=['fc'])
最後に、NFNetに関するYouTubeビデオの再生回数は30,000回を超えている。