Facebookが開発するオープンソースのディープラーニングフレームワークであるPyTorchが、バージョン1.10のリリースを発表した。CUDA Graphs APIのインテグレーション、JITコンパイラのアップデートによるCPUパフォーマンスの向上に加え、Android Networks API(NNAPI)のベータサポートが含まれている。ドメイン固有ライブラリであるTorchVisionとTorchAudioの新バージョンもリリースされた。
PyTorchチームが先日のブログ記事で、今回のリリースのおもな機能について紹介している。新リリースではいくつかの分散トレーニング機能とFXモジュール、torch.specialモジュールがベータ版から安定版に移行した。CUDA Graphs APIのインテグレーションによるCPUオーバーヘッドの低減、複数操作の結合が可能なLLVMベースのJITコンパイラなど、CPUパフォーマンスを向上するアップデートも含まれる。Android NNAPIのサポートがプロトタイプから安定版に移行し、テスト用途でのホスト上でのモジュール実行が可能になった。ディープラーニングアプリケーションの運用デプロイメントを高速化する新SDKのTorchXも含まれている。バージョン1.10全体として、1.9以降の426コントリビュータによる400以上のコミットが含まれている。
AndroidアプリケーションでGPUやNeural Processing Units(NPUs)といったハードウェアアクセラレータを使用可能にするAndroid NNAPIは、昨年プロトタイプサポートが追加されたものだが、 新リリースではこれがベータに移行すると同時に、対象とするオペレーションの拡張、ロード時のフレキシブルなtensor shape、モバイルホスト上でのモデルテスト機能などが追加されている。今回リリースに含まれるベータ機能としては、その他にもCUDA Graphs API統合がある。CUDA Graphsは、GPUに送信されたワークをキャプチャして再生することにより、CPUバウンダリなワークロードの実行時パフォーマンスを向上する。ワークのセットアップとディスパッチをスキップすることで、動的実行の柔軟性と引き換えにCPUのオーバーヘッドを削減するものだ。
今回のリリースでは、nn.Moduleのサブクラスに透過的なRPCを提供するRemoteモジュールや、分散データがプロセス間でグラディエント(gradient)を並行通信する方法をオーバーライドするためのDDP Communication Hook、トレーニング中に必要なメモリ量を削減するZeroRedundancyOptimizerなど、分散トレーニングに関する機能のいくつかが安定版に移行している。SciPyスペシャルモジュール同等のAPIと関数を提供するtorch.specialモジュールも安定版に移行した。
"PyTorchプログラムを変換および簡易化するPythonicなプラットフォーム"のFXモジュールがベータ版から安定版になった。おもなコンポーネントはシンボリックトレーサ、中間表現、Pythonコードジェネレータの3つだ。これらのコンポーネントにより、ModuleサブクラスをGraph表現に変換し、コードでそのGraphを修正し、新たなGraphを既存のPyTorch eager-executionシステムと互換性のあるPyThonソースコードに変換することが可能になる。FXの目標は、オペレータ融合(operator fusion)の実施やインストラクションの挿入といった独自コード用の変換処理を、開発者自身で記述できるようにすることだ。最新リリースでは、このモジュールがベータ版から安定版に移行した。
Hacker News上でのリリースに関する議論では、FXのような機能によって、PyTorchがJAXの方向に向かっているのではないか、という意見があった。これに対して、PyTorchの開発者であるHorace He氏が、次のように返答している。
FXは"JAXの方向に向かう"というより、FXモジュールとして変換処理を記述するためのツールキットと言った方がよいでしょう(類似する点は確かにありますが!)。"JAXの方向"が何を意味するのかについては不明確な部分はありますが、私自身は、1. 構成可能な変換、2. 関数型プログラミング(関数変換に関連する)、ということだと理解しています。PyTorchは1.の方向には進んでいますが、2.は違うと思います。
JAXはGoogleが2019年にオープンソースとして公開した、"ハイパフォーマンスなマシンラーニング研究"用のライブラリだ。Googleのディープラーニングの主力フレームワークであるTensorFlowは、PyTorchの最大のライバルで、今年初めにバージョン2.6をリリースしている。
PyTorch 1.10のリリースノートとコードはGitHubから入手が可能だ。