Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Added functionality to load from pipeline config for lstm framework.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow.compat.v1 as tf | |
from google.protobuf import text_format | |
from lstm_object_detection.protos import input_reader_google_pb2 # pylint: disable=unused-import | |
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2 | |
from object_detection.protos import pipeline_pb2 | |
from object_detection.utils import config_util | |
def get_configs_from_pipeline_file(pipeline_config_path):
  """Loads pipeline configs from a text proto, including the LSTM extension.

  The base configs come from
  `object_detection.utils.config_util.get_configs_from_pipeline_file`. The
  file is also parsed here so that the `lstm_model` extension can be looked
  up on the resulting `TrainEvalPipelineConfig`.

  Args:
    pipeline_config_path: Path to a pipeline_pb2.TrainEvalPipelineConfig text
      proto.

  Returns:
    Dictionary of configuration objects as returned by the base helper, with
    an additional `lstm_model` entry when the pipeline config sets the
    `lstm_model` extension.
  """
  with tf.gfile.GFile(pipeline_config_path, "r") as config_file:
    pipeline_text = config_file.read()

  # Local parse; only used below to inspect the lstm_model extension.
  parsed_pipeline = pipeline_pb2.TrainEvalPipelineConfig()
  text_format.Merge(pipeline_text, parsed_pipeline)

  # The base helper is given the path (not the parsed proto) and builds the
  # standard configs dictionary from it.
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

  lstm_extension = internal_pipeline_pb2.lstm_model
  if not parsed_pipeline.HasExtension(lstm_extension):
    return configs

  configs["lstm_model"] = parsed_pipeline.Extensions[lstm_extension]
  return configs
def create_pipeline_proto_from_configs(configs):
  """Builds a pipeline_pb2.TrainEvalPipelineConfig from a configs dictionary.

  This is nearly the inverse of get_configs_from_pipeline_file(): it returns
  a `TrainEvalPipelineConfig` object rather than a file path. The base proto
  is produced by
  `object_detection.utils.config_util.create_pipeline_proto_from_configs`;
  the `lstm_model` entry, when present, is copied into the proto's
  `lstm_model` extension.

  Args:
    configs: Dictionary of configs. See get_configs_from_pipeline_file().

  Returns:
    The pipeline_pb2.TrainEvalPipelineConfig built from `configs`.
  """
  pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)

  if "lstm_model" not in configs:
    return pipeline_proto

  # Copy rather than alias, so the returned proto holds its own LSTM config.
  lstm_extension = pipeline_proto.Extensions[internal_pipeline_pb2.lstm_model]
  lstm_extension.CopyFrom(configs["lstm_model"])
  return pipeline_proto
def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    lstm_config_path=""):
  """Reads training configuration spread across several config files.

  The first five paths are forwarded unchanged to
  `object_detection.utils.config_util.get_configs_from_multiple_files`. The
  LSTM config, if a path is given, is parsed here from a text proto.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    lstm_config_path: Path to pipeline_pb2.LstmModel.

  Returns:
    Dictionary of configuration objects as returned by the base helper, plus
    an `lstm_model` entry when `lstm_config_path` is a non-empty string.
  """
  base_config_paths = {
      "model_config_path": model_config_path,
      "train_config_path": train_config_path,
      "train_input_config_path": train_input_config_path,
      "eval_config_path": eval_config_path,
      "eval_input_config_path": eval_input_config_path,
  }
  configs = config_util.get_configs_from_multiple_files(**base_config_paths)

  # An empty path means no LSTM config was supplied; nothing more to add.
  if not lstm_config_path:
    return configs

  with tf.gfile.GFile(lstm_config_path, "r") as lstm_file:
    lstm_text = lstm_file.read()

  lstm_model = internal_pipeline_pb2.LstmModel()
  text_format.Merge(lstm_text, lstm_model)
  configs["lstm_model"] = lstm_model
  return configs