BT

最新技術を追い求めるデベロッパのための情報コミュニティ

寄稿

Topics

地域を選ぶ

InfoQ ホームページ ニュース Google DeepMindのNFNetがディープラーニングを効率化

Google DeepMindのNFNetがディープラーニングを効率化

原文(投稿日:2021/03/26)へのリンク

GoogleのDeepMind AI企業は最近NFNetsをリリースした。これは、ノーマライザーフリーResNet画像分類モデルであり、現在の最先端のEfficientNetよりも8.7倍速いトレーニングパフォーマンスを実現した

GoogleのDeepMindの研究者によると次の通りである(以下のプロットをチェックしてください)。

NFNet-F1モデルは、EfficientNet-B7と同様の精度を実現している。一方で、トレーニングでは8.7倍高速である。我々の最大のモデルは、86.5%のトップ1精度の追加データなしで、新しい総合的な最先端技術に位置付けられる。

大規模な画像認識タスクの場合、通常、ニューラルネットワークはバッチ正規化と呼ばれる手法を使用し、それによってモデルトレーニングがより効率的になる。さらに、ニューラルネットワークがより一般化する助けとなる。つまり、正則化効果がある。

バッチ正規化には、トレーニング時間と推論時間の間の食い違いの動作や、計算オーバーヘッドといったいくつかの欠点があある。計算オーバーヘッドは、後のバックプロパゲーション(ニューラルネットワーク学習プロセス)に必要なネットワーク層ごとの特定のパラメーターの格納によるものである。

DeepMindは、方程式から正規化を削除し、トレーニングパフォーマンスを向上させるためにNFNetを導入した。これに加えて、適応勾配クリッピングと呼ばれる手法が導入されている。これによって、ResNetなどのニューラルネットワークモデルをより大きなバッチサイズで効率的にトレーニングできる。この方法では、同じ精度のEfficientNetと比較して、計算リソース(使用されるGPUの量)ごとにトレーニング時間が20~40%短縮された。

出典: 正規化なしの高性能大規模画像認識

コードはGoogleのDeepMind GitHubで公開され、JAXと呼ばれるこの新しいフレームワークに実装された。 NFNetでフォワードステップを実行するには、次のコードを実行するだけである。

def forward(inputs, is_training):
 model = nfnet.NFNet(num_classes=1000,  variant=variant)
 return model(inputs, is_training=is_training)['logits']
net = hk.without_apply_rng(hk.transform(forward))
fwd = jax.jit(lambda inputs: net.apply(params, inputs, is_training=False))
# We split this into two cells so that we don't repeatedly jit the fwd fn.
logits = fwd(x[None]) # Give X a newaxis to make it batch-size-1
which_class = imagenet_classlist[int(logits.argmax())]
print(f'ImageNet class: {which_class}.')

NFNetsはPytorchにも実装されており、コミュニティがこのリリースを受け入れていることを示している。

import torch
from torch import nn, optim
from torchvision.models import resnet18

from nfnets import WSConv2d
from nfnets.agc import AGC # Needs testing

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)

optim = optim.SGD(conv.parameters(), 1e-3)
optim_agc = AGC(conv.parameters(), optim) # Needs testing

# Ignore fc of a model while applying AGC.
model = resnet18()
optim = torch.optim.SGD(model.parameters(), 1e-3)
optim = AGC(model.parameters(), optim, model=model, ignore_agc=['fc'])

最後に、NFNetに関するYouTubeビデオの再生回数は30,000回を超えている。

 

この記事に星をつける

おすすめ度
スタイル

BT