# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Vision models export utility function for serving/inference."""

import os
from typing import Optional, List, Union, Text, Dict

from absl import logging
import tensorflow as tf, tf_keras

from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision import configs
from official.vision.serving import detection
from official.vision.serving import image_classification
from official.vision.serving import semantic_segmentation
from official.vision.serving import video_classification


def _create_export_module(
    params: cfg.ExperimentConfig,
    batch_size: Optional[int],
    input_image_size: List[int],
    input_type: str,
    num_channels: Optional[int],
    input_name: Optional[str],
) -> export_base.ExportModule:
  """Builds the serving export module that matches the task type in `params`.

  Args:
    params: Experiment params; `params.task` selects the module class.
    batch_size: Batch size passed through to the module, or None.
    input_image_size: Height and width passed through to the module.
    input_type: Serving input type passed through to the module.
    num_channels: Number of input image channels.
    input_name: Input tensor name passed through to the module.

  Returns:
    An instance of the export module class registered for `params.task`.

  Raises:
    ValueError: If no export module is implemented for `type(params.task)`.
  """
  # Checked in order; the first matching task type wins (same precedence as
  # an if/elif chain).
  task_to_module = (
      (configs.image_classification.ImageClassificationTask,
       image_classification.ClassificationModule),
      ((configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask),
       detection.DetectionModule),
      (configs.semantic_segmentation.SemanticSegmentationTask,
       semantic_segmentation.SegmentationModule),
      (configs.video_classification.VideoClassificationTask,
       video_classification.VideoClassificationModule),
  )
  for task_types, module_cls in task_to_module:
    if isinstance(params.task, task_types):
      return module_cls(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
          input_name=input_name)
  raise ValueError('Export module not implemented for {} task.'.format(
      type(params.task)))


def _tpu_alias_save_options(
    export_module: export_base.ExportModule,
    input_type: str,
    save_options: Optional[tf.saved_model.SaveOptions],
) -> tf.saved_model.SaveOptions:
  """Returns `SaveOptions` aliasing the inference function as 'tpu_candidate'.

  Args:
    export_module: Module whose inference method for `input_type` is aliased.
    input_type: One of `image_tensor`, `image_bytes` or `tf_example`.
    save_options: Caller-provided options. They are replaced, not merged; a
      warning is logged when they are non-None.

  Returns:
    A new `tf.saved_model.SaveOptions` with only `function_aliases` set.

  Raises:
    ValueError: If `input_type` has no inference method that can be aliased.
  """
  inference_method_names = {
      'image_tensor': 'inference_from_image_tensors',
      'image_bytes': 'inference_from_image_bytes',
      'tf_example': 'inference_from_tf_example',
  }
  if input_type not in inference_method_names:
    raise ValueError(
        'add_tpu_function_alias is only allowed for input_type of:'
        ' image_tensor, image_bytes, tf_example.'
    )
  if save_options is not None:
    # The previous implementation dropped caller options silently; make the
    # override visible so it is not mistaken for a merge.
    logging.warning(
        'add_tpu_function_alias=True replaces the provided save_options with '
        'new SaveOptions that only set function_aliases.')
  inference_func = getattr(export_module, inference_method_names[input_type])
  return tf.saved_model.SaveOptions(
      function_aliases={
          'tpu_candidate': inference_func,
      }
  )


def _log_model_flops_and_params(
    params: cfg.ExperimentConfig,
    export_module: export_base.ExportModule,
    input_image_size: List[int],
    num_channels: Optional[int],
    export_dir: str,
) -> None:
  """Writes model FLOPs and parameters to files under `export_dir`.

  Only RetinaNet and Mask R-CNN tasks are supported; for any other task an
  info message is logged and nothing is written.

  Args:
    params: Experiment params; `params.task` decides whether this is supported.
    export_module: Module whose `model` is measured.
    input_image_size: Height and width (list or tuple) of the dummy input.
    num_channels: Number of channels of the dummy input.
    export_dir: Directory receiving model_flops.txt and model_params.txt.
  """
  if not isinstance(
      params.task,
      (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
    logging.info(
        'Logging model flops and params not implemented for %s task.',
        type(params.task))
    return

  # We need to create inputs_kwargs argument to specify the input shapes for
  # subclass model that overrides model.call to take multiple inputs,
  # e.g., RetinaNet model. Unpacking (rather than list concatenation) accepts
  # `input_image_size` as either a list or a tuple.
  inputs_kwargs = {
      'images': tf.TensorSpec([1, *input_image_size, num_channels],
                              tf.float32),
      'image_shape': tf.TensorSpec([1, 2], tf.float32),
  }
  dummy_inputs = {
      k: tf.ones(v.shape.as_list(), tf.float32)
      for k, v in inputs_kwargs.items()
  }
  # Must do forward pass to build the model.
  export_module.model(**dummy_inputs)

  train_utils.try_count_flops(export_module.model, inputs_kwargs,
                              os.path.join(export_dir, 'model_flops.txt'))
  train_utils.write_model_params(export_module.model,
                                 os.path.join(export_dir, 'model_params.txt'))


def export_inference_graph(
    input_type: str,
    batch_size: Optional[int],
    input_image_size: List[int],
    params: cfg.ExperimentConfig,
    checkpoint_path: str,
    export_dir: str,
    num_channels: Optional[int] = 3,
    export_module: Optional[export_base.ExportModule] = None,
    export_checkpoint_subdir: Optional[str] = None,
    export_saved_model_subdir: Optional[str] = None,
    save_options: Optional[tf.saved_model.SaveOptions] = None,
    log_model_flops_and_params: bool = False,
    checkpoint: Optional[tf.train.Checkpoint] = None,
    input_name: Optional[str] = None,
    function_keys: Optional[Union[List[Text], Dict[Text, Text]]] = None,
    add_tpu_function_alias: Optional[bool] = False,
):
  """Exports inference graph for the model specified in the exp config.

  Saved model is stored at export_dir/saved_model, checkpoint is saved
  at export_dir/checkpoint, and params is saved at export_dir/params.yaml.

  Args:
    input_type: One of `image_tensor`, `image_bytes`, `tf_example` or `tflite`.
    batch_size: 'int', or None.
    input_image_size: List or Tuple of height and width.
    params: Experiment params.
    checkpoint_path: Trained checkpoint path or directory.
    export_dir: Export directory path.
    num_channels: The number of input image channels.
    export_module: Optional export module to be used instead of using params
      to create one. If None, the params will be used to create an export
      module.
    export_checkpoint_subdir: Optional subdirectory under export_dir
      to store checkpoint.
    export_saved_model_subdir: Optional subdirectory under export_dir
      to store saved model.
    save_options: `SaveOptions` for `tf.saved_model.save`.
    log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
      and model parameters to model_params.txt. Only implemented for RetinaNet
      and Mask R-CNN tasks; other tasks log a message and skip this step.
    checkpoint: An optional tf.train.Checkpoint. If provided, the export module
      will use it to read the weights.
    input_name: The input tensor name, default at `None` which produces input
      tensor name `inputs`.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signature keys will be set with defaults. If a dictionary
      is provided, the values will be used as signature keys. Defaults to
      `[input_type]` when empty or None.
    add_tpu_function_alias: Whether to add TPU function alias so that it can be
      converted to a TPU compatible saved model later. Default is False. Only
      allowed for `image_tensor`, `image_bytes` and `tf_example` input types.
      When enabled, any provided `save_options` is replaced (a warning is
      logged).

  Raises:
    ValueError: If `export_module` is None and no export module is implemented
      for `type(params.task)`, or if `add_tpu_function_alias` is set for an
      unsupported `input_type`.
  """
  if export_checkpoint_subdir:
    output_checkpoint_directory = os.path.join(
        export_dir, export_checkpoint_subdir)
  else:
    output_checkpoint_directory = None

  if export_saved_model_subdir:
    output_saved_model_directory = os.path.join(
        export_dir, export_saved_model_subdir)
  else:
    output_saved_model_directory = export_dir

  # TODO(arashwan): Offers a direct path to use ExportModule with Task objects.
  if export_module is None:
    export_module = _create_export_module(
        params=params,
        batch_size=batch_size,
        input_image_size=input_image_size,
        input_type=input_type,
        num_channels=num_channels,
        input_name=input_name)

  if add_tpu_function_alias:
    save_options = _tpu_alias_save_options(
        export_module, input_type, save_options)

  export_base.export(
      export_module,
      function_keys=function_keys if function_keys else [input_type],
      export_savedmodel_dir=output_saved_model_directory,
      checkpoint=checkpoint,
      checkpoint_path=checkpoint_path,
      timestamped=False,
      save_options=save_options)

  if output_checkpoint_directory:
    ckpt = tf.train.Checkpoint(model=export_module.model)
    ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
  train_utils.serialize_config(params, export_dir)

  if log_model_flops_and_params:
    _log_model_flops_and_params(
        params, export_module, input_image_size, num_channels, export_dir)