Kerasフレームワーク作者のFrancois Chollet氏とそのチームは先頃、類似性モデル(similarity model)の構築を容易にするTensorFlow用のPythonライブラリをTensorFlow Similarityという名称でリリースした。
類似性学習(similarity learning)とは、画像内の似通った衣類から始まり、顔写真を使用した人物識別に至るまで、類似アイテムの検出を行うプロセスである。ディープラーニングモデルでは、イメージ間の類似性学習の正確性と効率を向上させるために対照学習(contrastive learning)と呼ばれる手法が使用される。対照学習モデルでは、複数の類似する畳み込みネットワークアーキテクチャの生成する埋め込み特徴ベクトル(一連の畳み込みのベクトル出力)が、画像間の類似性の正負ケースを評価/比較する対照性損失(contrastive loss)に送られる。簡単な例は、さまざまな写真の同一人物の顔を識別する処理だ。
TensorFlow Similarityのメリットのひとつは、事前学習モデルを使用した高速クエリ検索インデクスにある。つまり、もし犬を探したければ、そのイメージをAPIに提供すればよい。モデルベースから類似するアイテムが検索されて、高精度なアイテム取得が線形時間内に取得できる。
もうひとつのメリットは、スクラッチから再トレーニングする必要なく、モデルに新たな検索カテゴリを容易に統合できることだ。
MNISTデータセット用のコード例は、20行以内で記述することができる。
from tensorflow.keras import layers
# Embedding output layer with L2 norm
from tensorflow_similarity.layers import MetricEmbedding
# Specialized metric loss
from tensorflow_similarity.losses import MultiSimilarityLoss
# Sub classed keras Model with support for indexing
from tensorflow_similarity.models import SimilarityModel
# Data sampler that pulls datasets directly from tf dataset catalog
from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler
# Nearest neighbor visualizer
from tensorflow_similarity.visualization import viz_neigbors_imgs
# Data sampler that generates balanced batches from MNIST dataset
sampler = TFDatasetMultiShotMemorySampler(dataset_name='mnist', classes_per_batch=10)
# Build a Similarity model using standard Keras layers
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1/255)(inputs)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
outputs = MetricEmbedding(64)(x)
# Build a specialized Similarity model
model = SimilarityModel(inputs, outputs)
# Train Similarity model using contrastive loss
model.compile('adam', loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)
# Index 100 embedded MNIST examples to make them searchable
sx, sy = sampler.get_slice(0,100)
model.index(x=sx, y=sy, data=sx)
# Find the top 5 most similar indexed MNIST examples for a given example
qx, qy = sampler.get_slice(3713, 1)
nns = model.single_lookup(qx[0])
# Visualize the query example and its top 5 neighbors
viz_neigbors_imgs(qx[0], qy[0], nns)
出典: https://github.com/tensorflow/similarity
図: 類似性モデルは、類似したアイテムは互いに近く、異なるアイテムは遠い距離空間にアイテムを投影するという、埋め込み(embedding)の出力を学習する。
現時点では教師ありモデル(supervised model)のみが使用可能で、APIはまだベータ版である。このAPIはKaras.modelで実装された任意の教師ありモデルを使用できるが、例として示されているのはEfficientNetのみだ。EfficientNetは既存のConv-Netの最高値よりも6.1倍速く、8.4倍小さいという、極めて効率のよい畳み込みネットワークアーキテクチャである。
実装されているのはEfficientNetだけだが、以下のように独自の類似性モデルを開発することもできる。
def get_model():
inputs = layers.Input(shape=(28, 28, 1))
x = layers.experimental.preprocessing.Rescaling(1/255)(inputs)
x = layers.Conv2D(32, 7, activation='relu')(x)
x = layers.Conv2D(32, 3, activation='relu')(x)
x = layers.MaxPool2D()(x)
x = layers.Conv2D(64, 7, activation='relu')(x)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
# smaller embeddings will have faster lookup times while a larger embedding will improve the accuracy up to a point.
outputs = MetricEmbedding(64)(x)
return SimilarityModel(inputs, outputs)
model = get_model()
model.summary()
その他にもPytorch Metric Learningのような、類似性学習のための非公式なライブラリもあるが、使用には相応の知識が必要なようだ。
コミュニティはこの新しいTensorFlowツールを暖かく迎え入れており、Twitterでは何千ものシェアを得ている。
APIは概念上、相似性モデル、距離メトリクス、距離および損失関数に分けられている。損失関数としてはTriplet Loss、PN Loss、Multi Sim Loss、Circle Lossが利用可能だ。
図: Tensorflow Similarity APIチャートフロー(出典)
最後に、APIのコードはTensorFlow Similarity Apache 2.0 repository LICENSEに従って、適切な作者表示を含めることで使用できる。