Spaces:
Runtime error
Runtime error
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Added functionality to load from pipeline config for lstm framework.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import tensorflow.compat.v1 as tf | |
| from google.protobuf import text_format | |
| from lstm_object_detection.protos import input_reader_google_pb2 # pylint: disable=unused-import | |
| from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2 | |
| from object_detection.protos import pipeline_pb2 | |
| from object_detection.utils import config_util | |
def get_configs_from_pipeline_file(pipeline_config_path):
  """Loads pipeline configs, including the LSTM extension when present.

  The text proto at `pipeline_config_path` is parsed into a
  pipeline_pb2.TrainEvalPipelineConfig so that the `lstm_model` extension can
  be inspected. The remaining config objects are obtained from
  config_util.get_configs_from_pipeline_file().

  Args:
    pipeline_config_path: Path to pipeline_pb2.TrainEvalPipelineConfig text
      proto.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
    `train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
    Values are the corresponding config objects. The `lstm_model` key is only
    added when the parsed pipeline proto carries that extension.
  """
  with tf.gfile.GFile(pipeline_config_path, "r") as config_file:
    proto_text = config_file.read()

  # Parse locally so the lstm_model extension can be looked up below.
  pipeline_proto = pipeline_pb2.TrainEvalPipelineConfig()
  text_format.Merge(proto_text, pipeline_proto)

  # NOTE(review): config_util is handed the path rather than the parsed proto,
  # so it presumably reads the file again -- confirm whether that is intended.
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

  lstm_extension = internal_pipeline_pb2.lstm_model
  if pipeline_proto.HasExtension(lstm_extension):
    configs["lstm_model"] = pipeline_proto.Extensions[lstm_extension]
  return configs
def create_pipeline_proto_from_configs(configs):
  """Builds a pipeline_pb2.TrainEvalPipelineConfig from a configs dictionary.

  This is nearly the inverse of get_configs_from_pipeline_file(): rather than
  producing a file path, it produces a `TrainEvalPipelineConfig` object. The
  base proto comes from config_util.create_pipeline_proto_from_configs(); if
  `configs` contains an `lstm_model` entry, it is copied into the proto's
  `lstm_model` extension.

  Args:
    configs: Dictionary of configs. See get_configs_from_pipeline_file().

  Returns:
    A fully populated pipeline_pb2.TrainEvalPipelineConfig.
  """
  pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
  if "lstm_model" not in configs:
    return pipeline_proto

  # Attach the LSTM settings through the proto extension field.
  lstm_extension = pipeline_proto.Extensions[internal_pipeline_pb2.lstm_model]
  lstm_extension.CopyFrom(configs["lstm_model"])
  return pipeline_proto
def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    lstm_config_path=""):
  """Loads training configuration spread across several config files.

  The five non-LSTM paths are forwarded to
  config_util.get_configs_from_multiple_files(). When `lstm_config_path` is
  non-empty, that file is parsed as a text-format LstmModel proto and stored
  under the `lstm_model` key.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    lstm_config_path: Path to pipeline_pb2.LstmModel.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
    `train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
    Key/Values are returned only for valid (non-empty) strings.
  """
  base_paths = {
      "model_config_path": model_config_path,
      "train_config_path": train_config_path,
      "train_input_config_path": train_input_config_path,
      "eval_config_path": eval_config_path,
      "eval_input_config_path": eval_input_config_path,
  }
  configs = config_util.get_configs_from_multiple_files(**base_paths)

  # No LSTM config requested: the base configs are the full result.
  if not lstm_config_path:
    return configs

  lstm_proto = internal_pipeline_pb2.LstmModel()
  with tf.gfile.GFile(lstm_config_path, "r") as lstm_file:
    lstm_text = lstm_file.read()
  text_format.Merge(lstm_text, lstm_proto)
  configs["lstm_model"] = lstm_proto
  return configs