前回の記事では、使用したネットワークをTensorFlow Liteに変換し、Androidで使用してみました。
実施した手順は以下のとおりです。
1. Python上でテストを実施するのと同じ手順でネットワークを構築し、学習済みのモデルを読み込む
2. TFLiteConverterを使用して、ネットワークをtflite形式で出力する
3. tflite形式のデータに、tflite_supportライブラリを使ってメタデータを追加する
では、順に手順を追っていきましょう。
1.〜2. Python上でテストを実施するのと同じ手順でネットワークを構築し、tflite形式で出力する
以下のコードで実施します。
def convertToTFlite(self, args):
    """Convert the trained CycleGAN generator for one direction to TFLite.

    Initializes all variables, restores the checkpoint via ``self.load``,
    converts the session graph from the test input placeholder to the matching
    generator output, and writes the result to ``converted_model.tflite`` in
    the current working directory.

    Args:
        args: Parsed command-line options. Reads ``which_direction``
            ('AtoB' or 'BtoA') and ``checkpoint_dir``.

    Raises:
        Exception: If ``args.which_direction`` is neither 'AtoB' nor 'BtoA'.
    """
    # Select the (input placeholder, generator output) pair for the requested
    # direction. The original code globbed the test images here, but the file
    # list was never used; only the direction check mattered.
    if args.which_direction == 'AtoB':
        in_var, out_var = self.test_A, self.testB
    elif args.which_direction == 'BtoA':
        in_var, out_var = self.test_B, self.testA
    else:
        raise Exception('--which_direction must be AtoB or BtoA')

    init_op = tf.global_variables_initializer()
    self.sess.run(init_op)
    if self.load(args.checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        # NOTE(review): conversion still proceeds in this case, so the exported
        # model would carry freshly initialized (untrained) weights.
        print(" [!] Load failed...")

    converter = lite.TFLiteConverter.from_session(self.sess, [in_var], [out_var])
    tflite_model = converter.convert()
    # Context manager guarantees the file is flushed and closed.
    with open("converted_model.tflite", "wb") as f:
        f.write(tflite_model)
3. tflite形式のデータにtflite_supportライブラリを使ってメタデータを追加する
以下のようなコードを使用して追加します。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb
# Updated by Margaret
FLAGS = flags.FLAGS

# Command-line options for the metadata writer; both default to None.
flags.DEFINE_string(
    name="model_file",
    default=None,
    help="Path and file name to the TFLite model file.")
flags.DEFINE_string(
    name="export_directory",
    default=None,
    help="Path to save the TFLite model files with metadata.")
class MetadataPopulatorForGANModel(object):
    """Populates TFLite metadata for the male-to-female gender-swap GAN model.

    Construct with the path of a .tflite file, then call ``populate()``; the
    metadata is written into that file in place.
    """

    def __init__(self, model_file):
        """Stores the target model path.

        Args:
            model_file: Path to the .tflite model that will receive metadata.
        """
        self.model_file = model_file
        # Serialized ModelMetadata flatbuffer; filled in by _create_metadata().
        self.metadata_buf = None

    def populate(self):
        """Creates the metadata and then populates it into the model file."""
        self._create_metadata()
        self._populate_metadata()

    @staticmethod
    def _create_image_tensor_metadata(name, description, mean, std,
                                      stats_max, stats_min):
        """Builds TensorMetadataT for an RGB image tensor with one normalization unit.

        Args:
            name: Tensor name stored in the metadata.
            description: Human-readable tensor description.
            mean: List of normalization means.
            std: List of normalization standard deviations.
            stats_max: List of max values stored in the tensor stats.
            stats_min: List of min values stored in the tensor stats.

        Returns:
            The populated ``_metadata_fb.TensorMetadataT``.
        """
        meta = _metadata_fb.TensorMetadataT()
        meta.name = name
        meta.description = description
        meta.content = _metadata_fb.ContentT()
        meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
        meta.content.contentProperties.colorSpace = (
            _metadata_fb.ColorSpaceType.RGB)
        meta.content.contentPropertiesType = (
            _metadata_fb.ContentProperties.ImageProperties)

        normalization = _metadata_fb.ProcessUnitT()
        normalization.optionsType = (
            _metadata_fb.ProcessUnitOptions.NormalizationOptions)
        normalization.options = _metadata_fb.NormalizationOptionsT()
        normalization.options.mean = mean
        normalization.options.std = std
        meta.processUnits = [normalization]

        stats = _metadata_fb.StatsT()
        stats.max = stats_max
        stats.min = stats_min
        meta.stats = stats
        return meta

    def _create_metadata(self):
        """Builds the model metadata flatbuffer and stores it in ``self.metadata_buf``."""
        # Model-level info.
        model_meta = _metadata_fb.ModelMetadataT()
        model_meta.name = "gendersSwapMtoF"
        model_meta.description = "Gender swap male to female."
        model_meta.version = "v1"
        model_meta.author = "TensorFlow"
        model_meta.license = ("Apache License. Version 2.0 "
                              "http://www.apache.org/licenses/LICENSE-2.0.")

        # Input image: normalization (x - 127.5) / 127.5.
        input_image_meta = self._create_image_tensor_metadata(
            name="selfie_image",
            description=(
                "The expected image is 256 x 256, with three channels "
                "(red, green, and blue) per pixel. Each value in the tensor is "
                "between -1 and 1."),
            mean=[127.5],
            std=[127.5],
            stats_max=[1.0],
            stats_min=[-1.0])

        # Output image: normalization (x - (-1.0)) / (1 / 127.5), i.e. (x + 1) * 127.5.
        # NOTE(review): output stats claim [0, 1] while the input stats use
        # [-1, 1]; confirm which range the metadata consumer expects here.
        output_image_meta = self._create_image_tensor_metadata(
            name="gender_swap_image_m_to_f",
            description="Image styled.",
            mean=[-1.0],
            std=[0.007843137254902],  # 1/127.5
            stats_max=[1.0],
            stats_min=[0.0])

        # Single subgraph with one input and one output tensor.
        subgraph = _metadata_fb.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [input_image_meta]
        subgraph.outputTensorMetadata = [output_image_meta]
        model_meta.subgraphMetadata = [subgraph]

        builder = flatbuffers.Builder(0)
        builder.Finish(
            model_meta.Pack(builder),
            _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        self.metadata_buf = builder.Output()

    def _populate_metadata(self):
        """Writes ``self.metadata_buf`` into the model file in place."""
        populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
        populator.load_metadata_buffer(self.metadata_buf)
        populator.populate()
def populate_metadata(model_file, export_directory=None):
    """Copies a TFLite model to the export directory and writes metadata into the copy.

    The source model is left untouched. After populating, the metadata is
    read back from the copy and also saved as ``<model basename>.json`` in
    the export directory.

    Args:
        model_file: Valid path to the source .tflite model file.
        export_directory: Directory that receives the model copy and the JSON
            dump. Defaults to the ``--export_directory`` flag. It is created
            if it does not already exist.
    """
    if export_directory is None:
        export_directory = FLAGS.export_directory
    # tf.io.gfile.copy fails when the destination directory is missing.
    tf.io.gfile.makedirs(export_directory)

    model_file_basename = os.path.basename(model_file)
    export_path = os.path.join(export_directory, model_file_basename)
    # Populate a copy so the original model file is not modified.
    tf.io.gfile.copy(model_file, export_path, overwrite=True)
    populator = MetadataPopulatorForGANModel(export_path)
    populator.populate()

    # Read the metadata back from the exported model and save it as JSON.
    displayer = _metadata.MetadataDisplayer.with_model_file(export_path)
    export_json_file = os.path.join(
        export_directory,
        os.path.splitext(model_file_basename)[0] + ".json")
    json_file = displayer.get_metadata_json()
    with open(export_json_file, "w") as f:
        f.write(json_file)

    print("Finished populating metadata and associated file to the model:")
    print(export_path)
    print("The metadata json file has been saved to:")
    print(export_json_file)
def main(_):
    """Writes metadata into a copy of the model given by --model_file.

    The positional argument (remaining argv passed by absl) is unused.
    """
    populate_metadata(FLAGS.model_file)


if __name__ == "__main__":
    # Both flags default to None, which would otherwise surface as an opaque
    # TypeError from os.path; let absl reject the invocation up front instead.
    flags.mark_flag_as_required("model_file")
    flags.mark_flag_as_required("export_directory")
    app.run(main)

0 件のコメント :
コメントを投稿