Fritzの刊行するHeartBeatは先日、GoogleのマシンラーニングライブラリであるTensorFlow.jsをブラウザで使って、Chrome Dinosaur Gameのプレーをコンピュータに教える方法について解説した、Aayush Arora氏による記事を公開した。
Chrome Dinosaur Game(T-Rex Gameとも呼ばれる)は5年前、あるユーザがインターネットから切断された状態でWebサイトにアクセスしようとしたときに、Chromeブラウザに現れたものだ。Chrome Dinosaur Gameは単純な無限ランナーで、プレーヤはサボテンを飛び越えたり、障害物を潜り抜けたりする。コントロールは基本的で、スペースバーを押すとジャンプし、下向きの矢印を押すとしゃがみ込む。目標はできる限り長く生き残ることであり、プレーヤが障害物を乗り越えた時間がタイマで測定される。
ゲームの性質を考慮して選択された機能セットは、ゲームのスピード 、現れる障害物の幅 とティラノサウルスからの距離だ。コンピュータはこれら3つの変数をマップして、ジャンプするかしないか、2つの判断のどちらを選択するかを学習する((ゲームのオリジナルバージョンでは、恐竜がしゃがみ込むこともできるが、今回の決定リストではモデル化されない)。コンピュータは試行錯誤によって学習し、ゲームに失敗するたびにトレーニングデータを収集し、蓄積された経験を用いてゲームを再開する。
Tensorflow.jsは、マシンラーニングライブラリとして使用されている。TensorFlowのチュートリアルでは、マシンラーニングの実装で従うべき手順を明確にしている。
今回の例では、トレーニングデータを使用しないで開始するため、最初のステップは実質的に空である。2番目のステップでArora氏は、逐次モデルをベースとして、いずれもシグモイド起動関数(sigmoid activation function)を備えた、入力レイヤと出力レイヤを持つニューラルネットワークを使用した。最初のレイヤは、ゲーム速度、現れる障害物の幅、ティラノサウルスからの距離という、前述の3つの予測変数を持ち、2番目と最後のレイヤの入力として機能する6つのユニットを計算する。最後のレイヤには2つの出力があり、それぞれの値がジャンプする確率、あるいはジャンプしない確率に対応する。
import * as tf from '@tensorflow/tfjs';
dino.model = tf.sequential();
dino.model.add(tf.layers.dense({
inputShape:[3],
activation:'sigmoid',
units:6
}))
dino.model.add(tf.layers.dense({
inputShape:[6],
activation:'sigmoid',
units:2
}))
3番目のステップでは、入力データを、TensorFlow.jsが処理できるテンソル(tensor)に変換する。
dino.model.fit(
tf.tensor2d(dino.training.inputs),
tf.tensor2d(dino.training.labels)
);
3番目のステップにはシャッフリング(shuffling)が実装されていないので、最初は空であるトレーニングセットに対して、コンピュータがゲームに失敗するたびに、段階的に入力が追加されていく。ここでの正規化は、トレーニングセット内の出力値を0から1の間に設定することで実現する。実際には、ティラノサウルスが障害物の回避に失敗した場合、対応する3入力(ゲーム速度、現れる障害物の幅、ティラノサウルスからの距離)は[1, 0]
、[0, 1]
のいずれかにマッピングされて 、第2レイヤの出力を符号化する。ティラノサウルスがジャンプして障害物の回避に失敗した場合、適切な決定はジャンプしないことである: [1, 0]
。逆に、ティラノサウルスがジャンプせずに障害物にぶつかった場合には、ジャンプすることが適切な決定となる: [0, 1]
。
4番目のステップとして、トレーニングデータが利用可能になると、モデルはmeanSquaredError
損失関数とAdamオプティマイザを使って、学習率0.1でトレーニングされる(Adamオプティマイザは実際には非常に効果的で、設定を必要としない)。
dino.model.compile({
loss:'meanSquaredError',
optimizer: tf.train.adam(0.1)
})
5番目のステップは、ゲームの繰り返し中に発生する。ゲームが進行し、3入力の新たな値が計算されると、予測が実行されて、実行すべきタイミングであれば(例えば、ティラノサウルスがジャンプ中でなければ)、"ジャンプする/しない"の判断が行われる。
if (!dino.jumping) {
// whenever the dino is not jumping decide whether it needs to jump or not
let action = 0;// variable for action 1 for jump 0 for not
// call model.predict on the state vecotr after converting it to tensor2d object
const prediction = dino.model.predict(tf.tensor2d([convertStateToVector(state)]));
// the predict function returns a tensor we get the data in a promise as result
// and based don result decide the action
const predictionPromise = prediction.data();
predictionPromise.then((result) => {
// converting prediction to action
if (result[1] > result[0]) {
// we want to jump
action = 1;
// set last jumping state to current state
dino.lastJumpingState = state;
} else {
// set running state to current state
dino.lastRunningState = state;
}
resolve(action);
});
Fritzは、iOSおよびAndroid開発者向けのマシンラーニングプラットフォームである。TensorFlow.jsは、Apache 2.0ライセンスの下で利用可能なオープンソースソフトウェアである。コントリビューションとフィードバックは、TensorFlowのGitHubプロジェクトを通して受け入れられている。いずれもTensorFlowのコントリビューションガイドラインに従うことが必要だ。