TensorFlow Hubで公開されている超解像度モデル「esrgan-tf2」をAndroidで使ってみます。
手順をは以下になります。
1. モデルをtflite形式に変換する
2. 作成したtfliteファイルにメタデータを追加する
3. AndroidStudioで実行コードを書く
1.~3.の手順のどれもが参考記事が少なく手間取りますので、記録しておく次第です。
1. モデルをtflite形式に変換する
幸いにもTensorFlow Liteで動かすためのチュートリアルをGoogleが公開しています。
この記事はAndroidで動かすための記事ではありませんがtflite形式を作るまでは参考になります。
この記事にした以外tflite形式を作りましょう。
まず、必要なPythonライブラリをインストールします。
pip install matplotlib tensorflow tensorflow-hub
以下のコードを使いtflite形式を作ります。
import tensorflow as tf import tensorflow_hub as hub print(tf.__version__) model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1") concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] #[1, 128, 128, 3]が、入力画像サイズを示しています。この場合、128x128、24ビットとなっています。 concrete_func.inputs[0].set_shape([1, 128, 128, 3]) converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() # Save the TF Lite model. with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f: f.write(tflite_model)
また、2021/9/12の時点のcondaでインストールされるTensorFlowで変換できませんでした。
解決方法は、StackOverflowに記載されていました。
これに従い「tensorflow-estimator」のバージョンを2.3.0にします。
pip install --upgrade tensorflow-estimator==2.3.0
ソースコードを「er2lite.py」で保存し、コマンドプロンプトなどで以下を実行します。
python er2lite.py
正常に終了すると、「ESRGAN.tflite」が作られます。
これで、Androidで使えるなら楽なのですが、メタデータを付加しないと容易に使えません。
2. 作成したtfliteファイルにメタデータを追加する
メタデータの付加にはtflite_supportを使います。以下でインストールできます。
pip install tflite-support
メタデータの追加方法は公式ページに記載があるのですが、必要な情報がほとんど掲載されていません。
Colabotryに掲載されているサンプルコードや、tflite-supportのソースコード内の情報を頼りに変換します。
試行錯誤で書いたメタデータの追加コードが以下です。
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os from absl import app from absl import flags import tensorflow as tf from tflite_support import flatbuffers from tflite_support import metadata as _metadata from tflite_support import metadata_schema_py_generated as _metadata_fb FLAGS = flags.FLAGS flags.DEFINE_string("model_file", None, "Path and file name to the TFLite model file.") flags.DEFINE_string("export_directory", None, "Path to save the TFLite model files with metadata.") class MetadataPopulatorForModel(object): def __init__(self, model_file): self.model_file = model_file self.metadata_buf = None def populate(self): self._create_metadata() self._populate_metadata() def _create_metadata(self): # Creates model info. model_meta = _metadata_fb.ModelMetadataT() model_meta.name = "esrgan-tf2" model_meta.description = ("esrgan-tf2") model_meta.version = "v1" model_meta.author = "TensorFlow" model_meta.license = ("MIT License." "https://opensource.org/licenses/MIT") # Creates info for the input, image. input_image_meta = _metadata_fb.TensorMetadataT() input_image_meta.name = "selfie_image" input_image_meta.description = ( "The expected image is with three channels " "(red, blue, and green) per pixel. Each value in the tensor is between" " 0 and 255.") input_image_meta.content = _metadata_fb.ContentT() input_image_meta.content.contentProperties = ( _metadata_fb.ImagePropertiesT()) input_image_meta.content.contentProperties.colorSpace = ( _metadata_fb.ColorSpaceType.RGB) input_image_meta.content.contentPropertiesType = ( _metadata_fb.ContentProperties.ImageProperties) input_image_normalization = _metadata_fb.ProcessUnitT() input_image_normalization.optionsType = ( _metadata_fb.ProcessUnitOptions.NormalizationOptions) input_image_normalization.options = _metadata_fb.NormalizationOptionsT() input_image_normalization.options.mean = [0] input_image_normalization.options.std = [1] input_image_meta.processUnits = [input_image_normalization] input_image_stats = _metadata_fb.StatsT() input_image_stats.max = [255.0] input_image_stats.min = [0.0] input_image_meta.stats = input_image_stats # Creates output info, anime image output_image_meta = _metadata_fb.TensorMetadataT() output_image_meta.name = "imagef" output_image_meta.description = "super scaled image" output_image_meta.content = _metadata_fb.ContentT() output_image_meta.content.contentProperties = _metadata_fb.ImagePropertiesT() output_image_meta.content.contentProperties.colorSpace = ( _metadata_fb.ColorSpaceType.RGB) output_image_meta.content.contentPropertiesType = ( _metadata_fb.ContentProperties.ImageProperties) output_image_normalization = _metadata_fb.ProcessUnitT() output_image_normalization.optionsType = ( _metadata_fb.ProcessUnitOptions.NormalizationOptions) output_image_normalization.options = _metadata_fb.NormalizationOptionsT() output_image_normalization.options.mean = [0] output_image_normalization.options.std = [1] # 1/127.5 output_image_meta.processUnits = [output_image_normalization] output_image_stats = _metadata_fb.StatsT() output_image_stats.max = [255.0] output_image_stats.min = [0.0] output_image_meta.stats = output_image_stats # Creates subgraph info. subgraph = _metadata_fb.SubGraphMetadataT() subgraph.inputTensorMetadata = [input_image_meta] # Updated by Margaret subgraph.outputTensorMetadata = [output_image_meta] # Updated by Margaret model_meta.subgraphMetadata = [subgraph] b = flatbuffers.Builder(0) b.Finish( model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) self.metadata_buf = b.Output() def _populate_metadata(self): """Populates metadata to the model file.""" populator = _metadata.MetadataPopulator.with_model_file(self.model_file) populator.load_metadata_buffer(self.metadata_buf) populator.populate() def populate_metadata(model_file): """Populates the metadata using the populator specified. Args: model_file: valid path to the model file. model_type: a type defined in StyleTransferModelType . """ # Populates metadata for the model. model_file_basename = os.path.basename(model_file) export_path = os.path.join(FLAGS.export_directory, model_file_basename) tf.io.gfile.copy(model_file, export_path, overwrite=True) populator = MetadataPopulatorForModel(export_path) populator.populate() # Displays the metadata that was just populated into the tflite model. displayer = _metadata.MetadataDisplayer.with_model_file(export_path) export_json_file = os.path.join( FLAGS.export_directory, os.path.splitext(model_file_basename)[0] + ".json") json_file = displayer.get_metadata_json() with open(export_json_file, "w") as f: f.write(json_file) print("Finished populating metadata and associated file to the model:") print(export_path) print("The metadata json file has been saved to:") print( os.path.join(FLAGS.export_directory, os.path.splitext(model_file_basename)[0] + ".json")) def main(_): populate_metadata(FLAGS.model_file) if __name__ == "__main__": app.run(main)
このプログラムを以下のように使ってメタデータを追記します。
python metadata_writer.py --model_file=./ESRGAN.tflite --export_directory=./with_meta
これでモデルデータの完成です。
3. AndroidStudioで実行コードを書く
ここまでくれば、Androidで使うのは簡単(?)です。
AndroidStudioでモデルファイルを取り込んだ後は、以下のようなコードで画像を変換できます。
private fun getHiResAsync(bitmap: Bitmap): Deferred<Bitmap> = // use async() to create a coroutine in an IO optimized Dispatcher for model inference coroutineScope.async(Dispatchers.IO) { val compatList = CompatibilityList() val options = if(compatList.isDelegateSupportedOnThisDevice && AppPreference.getUseGPU(context)) { // if the device has a supported GPU, add the GPU delegate Model.Options.Builder().setDevice(Model.Device.GPU).build() } else { // if the GPU is not supported, run on 4 threads Model.Options.Builder().setNumThreads(4).build() } //GPUだと動作しないので、CPUのみで動かす val model = ESRGAN.newInstance(requireContext(), options) // Creates inputs for reference. val selfieImage = TensorImage.fromBitmap(bitmap) // Runs model inference and gets result. val outputs = model.process(selfieImage) val imagef = outputs.imagefAsTensorImage val imagefBitmap = imagef.bitmap // Releases model resources if no longer used. model.close() return@async imagefBitmap }
ここまで頑張る価値があるかは人それぞれですが・・・。
0 件のコメント :
コメントを投稿