NCTCMumbai's picture
Upload 2571 files
0b8359d
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Added functionality to load from pipeline config for lstm framework."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from lstm_object_detection.protos import input_reader_google_pb2 # pylint: disable=unused-import
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from object_detection.protos import pipeline_pb2
from object_detection.utils import config_util
def get_configs_from_pipeline_file(pipeline_config_path):
"""Reads configuration from a pipeline_pb2.TrainEvalPipelineConfig.
Args:
pipeline_config_path: Path to pipeline_pb2.TrainEvalPipelineConfig text
proto.
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Value are the corresponding config objects.
"""
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(pipeline_config_path, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
if pipeline_config.HasExtension(internal_pipeline_pb2.lstm_model):
configs["lstm_model"] = pipeline_config.Extensions[
internal_pipeline_pb2.lstm_model]
return configs
def create_pipeline_proto_from_configs(configs):
"""Creates a pipeline_pb2.TrainEvalPipelineConfig from configs dictionary.
This function nearly performs the inverse operation of
get_configs_from_pipeline_file(). Instead of returning a file path, it returns
a `TrainEvalPipelineConfig` object.
Args:
configs: Dictionary of configs. See get_configs_from_pipeline_file().
Returns:
A fully populated pipeline_pb2.TrainEvalPipelineConfig.
"""
pipeline_config = config_util.create_pipeline_proto_from_configs(configs)
if "lstm_model" in configs:
pipeline_config.Extensions[internal_pipeline_pb2.lstm_model].CopyFrom(
configs["lstm_model"])
return pipeline_config
def get_configs_from_multiple_files(model_config_path="",
train_config_path="",
train_input_config_path="",
eval_config_path="",
eval_input_config_path="",
lstm_config_path=""):
"""Reads training configuration from multiple config files.
Args:
model_config_path: Path to model_pb2.DetectionModel.
train_config_path: Path to train_pb2.TrainConfig.
train_input_config_path: Path to input_reader_pb2.InputReader.
eval_config_path: Path to eval_pb2.EvalConfig.
eval_input_config_path: Path to input_reader_pb2.InputReader.
lstm_config_path: Path to pipeline_pb2.LstmModel.
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Key/Values are returned only for valid (non-empty) strings.
"""
configs = config_util.get_configs_from_multiple_files(
model_config_path=model_config_path,
train_config_path=train_config_path,
train_input_config_path=train_input_config_path,
eval_config_path=eval_config_path,
eval_input_config_path=eval_input_config_path)
if lstm_config_path:
lstm_config = internal_pipeline_pb2.LstmModel()
with tf.gfile.GFile(lstm_config_path, "r") as f:
text_format.Merge(f.read(), lstm_config)
configs["lstm_model"] = lstm_config
return configs