BT

最新技術を追い求めるデベロッパのための情報コミュニティ

寄稿

Topics

地域を選ぶ

InfoQ ホームページ ニュース トレーニング済モデル上で高速クエリ検索インデクスをサポートするTensorFlow Similarity

トレーニング済モデル上で高速クエリ検索インデクスをサポートするTensorFlow Similarity

原文(投稿日:2021/10/19)へのリンク

Kerasフレームワーク作者のFrancois Chollet氏とそのチームは先頃、類似性モデル(similarity model)の構築を容易にするTensorFlow用のPythonライブラリをTensorFlow Similarityという名称でリリースした

類似性学習(similarity learning)とは、画像内の似通った衣類から始まり、顔写真を使用した人物識別に至るまで、類似アイテムの検出を行うプロセスである。ディープラーニングモデルでは、イメージ間の類似性学習の正確性と効率を向上させるために対照学習(contrastive learning)と呼ばれる手法が使用される。対照学習モデルでは、複数の類似する畳み込みネットワークアーキテクチャの生成する埋め込み特徴ベクトル(一連の畳み込みのベクトル出力)が、画像間の類似性の正負ケースを評価/比較する対照性損失(contrastive loss)に送られる。簡単な例は、さまざまな写真の同一人物の顔を識別する処理だ。

TensorFlow Similarityのメリットのひとつは、事前学習モデルを使用した高速クエリ検索インデクスにある。つまり、もし犬を探したければ、そのイメージをAPIに提供すればよい。モデルベースから類似するアイテムが検索されて、高精度なアイテム取得が線形時間内に取得できる。 

もうひとつのメリットは、スクラッチから再トレーニングする必要なく、モデルに新たな検索カテゴリを容易に統合できることだ。

MNISTデータセット用のコード例は、20行以内で記述することができる。

from tensorflow.keras import layers
# Embedding output layer with L2 norm
from tensorflow_similarity.layers import MetricEmbedding 
# Specialized metric loss
from tensorflow_similarity.losses import MultiSimilarityLoss 
# Sub classed keras Model with support for indexing
from tensorflow_similarity.models import SimilarityModel
# Data sampler that pulls datasets directly from tf dataset catalog
from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler
# Nearest neighbor visualizer
from tensorflow_similarity.visualization import viz_neigbors_imgs
# Data sampler that generates balanced batches from MNIST dataset
sampler = TFDatasetMultiShotMemorySampler(dataset_name='mnist', classes_per_batch=10)
# Build a Similarity model using standard Keras layers
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1/255)(inputs)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
outputs = MetricEmbedding(64)(x)
# Build a specialized Similarity model
model = SimilarityModel(inputs, outputs)
# Train Similarity model using contrastive loss
model.compile('adam', loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)
# Index 100 embedded MNIST examples to make them searchable
sx, sy = sampler.get_slice(0,100)
model.index(x=sx, y=sy, data=sx)
# Find the top 5 most similar indexed MNIST examples for a given example
qx, qy = sampler.get_slice(3713, 1)
nns = model.single_lookup(qx[0])
# Visualize the query example and its top 5 neighbors
viz_neigbors_imgs(qx[0], qy[0], nns)

出典: https://github.com/tensorflow/similarity

図: 類似性モデルは、類似したアイテムは互いに近く、異なるアイテムは遠い距離空間にアイテムを投影するという、埋め込み(embedding)の出力を学習する。

現時点では教師ありモデル(supervised model)のみが使用可能で、APIはまだベータ版である。このAPIはKaras.modelで実装された任意の教師ありモデルを使用できるが、例として示されているのはEfficientNetのみだ。EfficientNetは既存のConv-Netの最高値よりも6.1倍速く、8.4倍小さいという、極めて効率のよい畳み込みネットワークアーキテクチャである。

実装されているのはEfficientNetだけだが、以下のように独自の類似性モデルを開発することもできる。

def get_model():
    inputs = layers.Input(shape=(28, 28, 1))
    x = layers.experimental.preprocessing.Rescaling(1/255)(inputs)
    x = layers.Conv2D(32, 7, activation='relu')(x)
    x = layers.Conv2D(32, 3, activation='relu')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(64, 7, activation='relu')(x)
    x = layers.Conv2D(64, 3, activation='relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(64, activation='relu')(x)
    # smaller embeddings will have faster lookup times while a larger embedding will improve the accuracy up to a point.
    outputs = MetricEmbedding(64)(x)
    return SimilarityModel(inputs, outputs)
model = get_model()
model.summary()

その他にもPytorch Metric Learningのような、類似性学習のための非公式なライブラリもあるが、使用には相応の知識が必要なようだ。

コミュニティはこの新しいTensorFlowツールを暖かく迎え入れており、Twitterでは何千ものシェアを得ている。

APIは概念上、相似性モデル、距離メトリクス、距離および損失関数に分けられている。損失関数としてはTriplet LossPN LossMulti Sim LossCircle Lossが利用可能だ。 

図: Tensorflow Similarity APIチャートフロー(出典)

最後に、APIのコードはTensorFlow Similarity Apache 2.0  repository LICENSEに従って、適切な作者表示を含めることで使用できる。

 

 

 

 

この記事に星をつける

おすすめ度
スタイル

BT