
TensorFlow Hubで公開されている超解像度モデル「esrgan-tf2」をAndroidで使ってみます。
手順をは以下になります。
1. モデルをtflite形式に変換する
2. 作成したtfliteファイルにメタデータを追加する
3. AndroidStudioで実行コードを書く
1.~3.の手順のどれもが参考記事が少なく手間取りますので、記録しておく次第です。
1. モデルをtflite形式に変換する
幸いにもTensorFlow Liteで動かすためのチュートリアルをGoogleが公開しています。
この記事はAndroidで動かすための記事ではありませんがtflite形式を作るまでは参考になります。
この記事にした以外tflite形式を作りましょう。
まず、必要なPythonライブラリをインストールします。
pip install matplotlib tensorflow tensorflow-hub
以下のコードを使いtflite形式を作ります。
import tensorflow as tf
import tensorflow_hub as hub
print(tf.__version__)
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
#[1, 128, 128, 3]が、入力画像サイズを示しています。この場合、128x128、24ビットとなっています。
concrete_func.inputs[0].set_shape([1, 128, 128, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
f.write(tflite_model)
また、2021/9/12の時点のcondaでインストールされるTensorFlowで変換できませんでした。
解決方法は、StackOverflowに記載されていました。
これに従い「tensorflow-estimator」のバージョンを2.3.0にします。
pip install --upgrade tensorflow-estimator==2.3.0
ソースコードを「er2lite.py」で保存し、コマンドプロンプトなどで以下を実行します。
python er2lite.py
正常に終了すると、「ESRGAN.tflite」が作られます。
これで、Androidで使えるなら楽なのですが、メタデータを付加しないと容易に使えません。
2. 作成したtfliteファイルにメタデータを追加する
メタデータの付加にはtflite_supportを使います。以下でインストールできます。
pip install tflite-support
メタデータの追加方法は公式ページに記載があるのですが、必要な情報がほとんど掲載されていません。
Colabotryに掲載されているサンプルコードや、tflite-supportのソースコード内の情報を頼りに変換します。
試行錯誤で書いたメタデータの追加コードが以下です。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb
FLAGS = flags.FLAGS
flags.DEFINE_string("model_file", None,
"Path and file name to the TFLite model file.")
flags.DEFINE_string("export_directory", None,
"Path to save the TFLite model files with metadata.")
class MetadataPopulatorForModel(object):
def __init__(self, model_file):
self.model_file = model_file
self.metadata_buf = None
def populate(self):
self._create_metadata()
self._populate_metadata()
def _create_metadata(self):
# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "esrgan-tf2"
model_meta.description = ("esrgan-tf2")
model_meta.version = "v1"
model_meta.author = "TensorFlow"
model_meta.license = ("MIT License."
"https://opensource.org/licenses/MIT")
# Creates info for the input, image.
input_image_meta = _metadata_fb.TensorMetadataT()
input_image_meta.name = "selfie_image"
input_image_meta.description = (
"The expected image is with three channels "
"(red, blue, and green) per pixel. Each value in the tensor is between"
" 0 and 255.")
input_image_meta.content = _metadata_fb.ContentT()
input_image_meta.content.contentProperties = (
_metadata_fb.ImagePropertiesT())
input_image_meta.content.contentProperties.colorSpace = (
_metadata_fb.ColorSpaceType.RGB)
input_image_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.ImageProperties)
input_image_normalization = _metadata_fb.ProcessUnitT()
input_image_normalization.optionsType = (
_metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_image_normalization.options = _metadata_fb.NormalizationOptionsT()
input_image_normalization.options.mean = [0]
input_image_normalization.options.std = [1]
input_image_meta.processUnits = [input_image_normalization]
input_image_stats = _metadata_fb.StatsT()
input_image_stats.max = [255.0]
input_image_stats.min = [0.0]
input_image_meta.stats = input_image_stats
# Creates output info, anime image
output_image_meta = _metadata_fb.TensorMetadataT()
output_image_meta.name = "imagef"
output_image_meta.description = "super scaled image"
output_image_meta.content = _metadata_fb.ContentT()
output_image_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
output_image_meta.content.contentProperties.colorSpace = (
_metadata_fb.ColorSpaceType.RGB)
output_image_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.ImageProperties)
output_image_normalization = _metadata_fb.ProcessUnitT()
output_image_normalization.optionsType = (
_metadata_fb.ProcessUnitOptions.NormalizationOptions)
output_image_normalization.options = _metadata_fb.NormalizationOptionsT()
output_image_normalization.options.mean = [0]
output_image_normalization.options.std = [1] # 1/127.5
output_image_meta.processUnits = [output_image_normalization]
output_image_stats = _metadata_fb.StatsT()
output_image_stats.max = [255.0]
output_image_stats.min = [0.0]
output_image_meta.stats = output_image_stats
# Creates subgraph info.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_image_meta] # Updated by Margaret
subgraph.outputTensorMetadata = [output_image_meta] # Updated by Margaret
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(
model_meta.Pack(b),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
self.metadata_buf = b.Output()
def _populate_metadata(self):
"""Populates metadata to the model file."""
populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
populator.load_metadata_buffer(self.metadata_buf)
populator.populate()
def populate_metadata(model_file):
"""Populates the metadata using the populator specified.
Args:
model_file: valid path to the model file.
model_type: a type defined in StyleTransferModelType .
"""
# Populates metadata for the model.
model_file_basename = os.path.basename(model_file)
export_path = os.path.join(FLAGS.export_directory, model_file_basename)
tf.io.gfile.copy(model_file, export_path, overwrite=True)
populator = MetadataPopulatorForModel(export_path)
populator.populate()
# Displays the metadata that was just populated into the tflite model.
displayer = _metadata.MetadataDisplayer.with_model_file(export_path)
export_json_file = os.path.join(
FLAGS.export_directory,
os.path.splitext(model_file_basename)[0] + ".json")
json_file = displayer.get_metadata_json()
with open(export_json_file, "w") as f:
f.write(json_file)
print("Finished populating metadata and associated file to the model:")
print(export_path)
print("The metadata json file has been saved to:")
print(
os.path.join(FLAGS.export_directory,
os.path.splitext(model_file_basename)[0] + ".json"))
def main(_):
populate_metadata(FLAGS.model_file)
if __name__ == "__main__":
app.run(main)
このプログラムを以下のように使ってメタデータを追記します。
python metadata_writer.py --model_file=./ESRGAN.tflite --export_directory=./with_meta
これでモデルデータの完成です。
3. AndroidStudioで実行コードを書く
ここまでくれば、Androidで使うのは簡単(?)です。
AndroidStudioでモデルファイルを取り込んだ後は、以下のようなコードで画像を変換できます。
private fun getHiResAsync(bitmap: Bitmap): Deferred<Bitmap> =
// use async() to create a coroutine in an IO optimized Dispatcher for model inference
coroutineScope.async(Dispatchers.IO) {
val compatList = CompatibilityList()
val options = if(compatList.isDelegateSupportedOnThisDevice && AppPreference.getUseGPU(context)) {
// if the device has a supported GPU, add the GPU delegate
Model.Options.Builder().setDevice(Model.Device.GPU).build()
} else {
// if the GPU is not supported, run on 4 threads
Model.Options.Builder().setNumThreads(4).build()
}
//GPUだと動作しないので、CPUのみで動かす
val model = ESRGAN.newInstance(requireContext(), options)
// Creates inputs for reference.
val selfieImage = TensorImage.fromBitmap(bitmap)
// Runs model inference and gets result.
val outputs = model.process(selfieImage)
val imagef = outputs.imagefAsTensorImage
val imagefBitmap = imagef.bitmap
// Releases model resources if no longer used.
model.close()
return@async imagefBitmap
}
ここまで頑張る価値があるかは人それぞれですが・・・。
0 件のコメント :
コメントを投稿