[Android][TensorFlow Lite]esrgan-tf2(超解像度)をAndroidで動かしてみる。

TensorFlow Hubで公開されている超解像度モデル「esrgan-tf2」をAndroidで使ってみます。

手順をは以下になります。

1. モデルをtflite形式に変換する
2. 作成したtfliteファイルにメタデータを追加する
3. AndroidStudioで実行コードを書く

1.~3.の手順のどれもが参考記事が少なく手間取りますので、記録しておく次第です。

1. モデルをtflite形式に変換する

幸いにもTensorFlow Liteで動かすためのチュートリアルをGoogleが公開しています。

この記事はAndroidで動かすための記事ではありませんがtflite形式を作るまでは参考になります。
この記事にした以外tflite形式を作りましょう。

まず、必要なPythonライブラリをインストールします。

pip install matplotlib tensorflow tensorflow-hub

以下のコードを使いtflite形式を作ります。

import tensorflow as tf
import tensorflow_hub as hub
print(tf.__version__)

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
#[1, 128, 128, 3]が、入力画像サイズを示しています。この場合、128x128、24ビットとなっています。
concrete_func.inputs[0].set_shape([1, 128, 128, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

また、2021/9/12の時点のcondaでインストールされるTensorFlowで変換できませんでした。
解決方法は、StackOverflowに記載されていました。
これに従い「tensorflow-estimator」のバージョンを2.3.0にします。

pip install --upgrade tensorflow-estimator==2.3.0

ソースコードを「er2lite.py」で保存し、コマンドプロンプトなどで以下を実行します。

python er2lite.py

正常に終了すると、「ESRGAN.tflite」が作られます。
これで、Androidで使えるなら楽なのですが、メタデータを付加しないと容易に使えません。

2. 作成したtfliteファイルにメタデータを追加する

メタデータの付加にはtflite_supportを使います。以下でインストールできます。

pip install tflite-support

メタデータの追加方法は公式ページに記載があるのですが、必要な情報がほとんど掲載されていません。
Colabotryに掲載されているサンプルコードや、tflite-supportのソースコード内の情報を頼りに変換します。

試行錯誤で書いたメタデータの追加コードが以下です。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags

import tensorflow as tf

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

FLAGS = flags.FLAGS
flags.DEFINE_string("model_file", None,
                    "Path and file name to the TFLite model file.")
flags.DEFINE_string("export_directory", None,
                    "Path to save the TFLite model files with metadata.")

class MetadataPopulatorForModel(object):

  def __init__(self, model_file):
    self.model_file = model_file
    self.metadata_buf = None

  def populate(self):
    self._create_metadata()
    self._populate_metadata()

  def _create_metadata(self):

    # Creates model info.
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = "esrgan-tf2" 
    model_meta.description = ("esrgan-tf2")
    model_meta.version = "v1"
    model_meta.author = "TensorFlow"
    model_meta.license = ("MIT License."
                          "https://opensource.org/licenses/MIT")

    # Creates info for the input, image.
    input_image_meta = _metadata_fb.TensorMetadataT()
    input_image_meta.name = "selfie_image"
    input_image_meta.description = (
        "The expected image is with three channels "
        "(red, blue, and green) per pixel. Each value in the tensor is between"
        " 0 and 255.")
    input_image_meta.content = _metadata_fb.ContentT()
    input_image_meta.content.contentProperties = (
        _metadata_fb.ImagePropertiesT())
    input_image_meta.content.contentProperties.colorSpace = (
        _metadata_fb.ColorSpaceType.RGB)
    input_image_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)
    input_image_normalization = _metadata_fb.ProcessUnitT()
    input_image_normalization.optionsType = (
        _metadata_fb.ProcessUnitOptions.NormalizationOptions)
    input_image_normalization.options = _metadata_fb.NormalizationOptionsT()
    input_image_normalization.options.mean = [0]
    input_image_normalization.options.std = [1]
    input_image_meta.processUnits = [input_image_normalization]
    input_image_stats = _metadata_fb.StatsT()
    input_image_stats.max = [255.0]
    input_image_stats.min = [0.0]
    input_image_meta.stats = input_image_stats


    # Creates output info, anime image
    output_image_meta = _metadata_fb.TensorMetadataT()
    output_image_meta.name = "imagef"
    output_image_meta.description = "super scaled image"
    output_image_meta.content = _metadata_fb.ContentT()
    output_image_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
    output_image_meta.content.contentProperties.colorSpace = (
        _metadata_fb.ColorSpaceType.RGB)
    output_image_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)
    output_image_normalization = _metadata_fb.ProcessUnitT()
    output_image_normalization.optionsType = (
        _metadata_fb.ProcessUnitOptions.NormalizationOptions)
    output_image_normalization.options = _metadata_fb.NormalizationOptionsT()
    output_image_normalization.options.mean = [0]
    output_image_normalization.options.std = [1]  # 1/127.5
    output_image_meta.processUnits = [output_image_normalization]
    output_image_stats = _metadata_fb.StatsT()
    output_image_stats.max = [255.0]
    output_image_stats.min = [0.0]
    output_image_meta.stats = output_image_stats

    # Creates subgraph info.
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.inputTensorMetadata = [input_image_meta] # Updated by Margaret
    subgraph.outputTensorMetadata = [output_image_meta] # Updated by Margaret
    model_meta.subgraphMetadata = [subgraph]

    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    self.metadata_buf = b.Output()

  def _populate_metadata(self):
    """Populates metadata to the model file."""
    populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
    populator.load_metadata_buffer(self.metadata_buf)
    populator.populate()


def populate_metadata(model_file):
  """Populates the metadata using the populator specified.

  Args:
      model_file: valid path to the model file.
      model_type: a type defined in StyleTransferModelType .
  """

  # Populates metadata for the model.
  model_file_basename = os.path.basename(model_file)
  export_path = os.path.join(FLAGS.export_directory, model_file_basename)
  tf.io.gfile.copy(model_file, export_path, overwrite=True)

  populator = MetadataPopulatorForModel(export_path)
  populator.populate()

  # Displays the metadata that was just populated into the tflite model.
  displayer = _metadata.MetadataDisplayer.with_model_file(export_path)
  export_json_file = os.path.join(
      FLAGS.export_directory,
      os.path.splitext(model_file_basename)[0] + ".json")
  json_file = displayer.get_metadata_json()
  with open(export_json_file, "w") as f:
    f.write(json_file)
  print("Finished populating metadata and associated file to the model:")
  print(export_path)
  print("The metadata json file has been saved to:")
  print(
      os.path.join(FLAGS.export_directory,
                   os.path.splitext(model_file_basename)[0] + ".json"))


def main(_):
  populate_metadata(FLAGS.model_file)


if __name__ == "__main__":
  app.run(main)

このプログラムを以下のように使ってメタデータを追記します。

python metadata_writer.py --model_file=./ESRGAN.tflite --export_directory=./with_meta

これでモデルデータの完成です。

3. AndroidStudioで実行コードを書く

ここまでくれば、Androidで使うのは簡単(?)です。
AndroidStudioでモデルファイルを取り込んだ後は、以下のようなコードで画像を変換できます。

    private fun getHiResAsync(bitmap: Bitmap): Deferred<Bitmap> =
        // use async() to create a coroutine in an IO optimized Dispatcher for model inference
        coroutineScope.async(Dispatchers.IO) {

            val compatList = CompatibilityList()

            val options = if(compatList.isDelegateSupportedOnThisDevice  && AppPreference.getUseGPU(context)) {
                // if the device has a supported GPU, add the GPU delegate
                Model.Options.Builder().setDevice(Model.Device.GPU).build()
            } else {
                // if the GPU is not supported, run on 4 threads
                Model.Options.Builder().setNumThreads(4).build()
            }

            //GPUだと動作しないので、CPUのみで動かす
            val model = ESRGAN.newInstance(requireContext(), options)

            // Creates inputs for reference.
            val selfieImage = TensorImage.fromBitmap(bitmap)

            // Runs model inference and gets result.
            val outputs = model.process(selfieImage)
            val imagef = outputs.imagefAsTensorImage
            val imagefBitmap = imagef.bitmap

            // Releases model resources if no longer used.
            model.close()

            return@async imagefBitmap
        }

ここまで頑張る価値があるかは人それぞれですが・・・。

0 件のコメント :

コメントを投稿