"""Functions for reading and updating configuration files.""" |
|
|
|
import os |
|
import tensorflow as tf |
|
|
|
from google.protobuf import text_format |
|
|
|
from tensorflow.python.lib.io import file_io |
|
|
|
from object_detection.protos import eval_pb2 |
|
from object_detection.protos import graph_rewriter_pb2 |
|
from object_detection.protos import input_reader_pb2 |
|
from object_detection.protos import model_pb2 |
|
from object_detection.protos import pipeline_pb2 |
|
from object_detection.protos import train_pb2 |
|
|
|
|
|
def get_image_resizer_config(model_config):
  """Returns the image resizer config from a model config.

  Args:
    model_config: A model_pb2.DetectionModel.

  Returns:
    An image_resizer_pb2.ImageResizer.

  Raises:
    ValueError: If the model type is not recognized.
  """
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "faster_rcnn":
    return model_config.faster_rcnn.image_resizer
  if meta_architecture == "ssd":
    return model_config.ssd.image_resizer

  raise ValueError("Unknown model type: {}".format(meta_architecture))
|
|
def get_spatial_image_size(image_resizer_config):
  """Returns expected spatial size of the output image from a given config.

  Args:
    image_resizer_config: An image_resizer_pb2.ImageResizer.

  Returns:
    A list of two integers of the form [height, width]. `height` and `width`
    are set to -1 if they cannot be determined during graph construction.

  Raises:
    ValueError: If the image resizer type is not recognized.
  """
  if image_resizer_config.HasField("fixed_shape_resizer"):
    return [
        image_resizer_config.fixed_shape_resizer.height,
        image_resizer_config.fixed_shape_resizer.width
    ]
  if image_resizer_config.HasField("keep_aspect_ratio_resizer"):
    if image_resizer_config.keep_aspect_ratio_resizer.pad_to_max_dimension:
      return [image_resizer_config.keep_aspect_ratio_resizer.max_dimension] * 2
    else:
      return [-1, -1]
  if image_resizer_config.HasField("identity_resizer"):
    return [-1, -1]
  raise ValueError("Unknown image resizer type.")
|
|
def get_configs_from_pipeline_file(pipeline_config_path, config_override=None):
  """Reads config from a file containing pipeline_pb2.TrainEvalPipelineConfig.

  Args:
    pipeline_config_path: Path to pipeline_pb2.TrainEvalPipelineConfig text
      proto.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override pipeline_config_path.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
    `train_input_config`, `eval_config`, `eval_input_configs`. Values are the
    corresponding config objects.
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(pipeline_config_path, "r") as f:
    proto_str = f.read()
    text_format.Merge(proto_str, pipeline_config)
  if config_override:
    text_format.Merge(config_override, pipeline_config)
  return create_configs_from_pipeline_proto(pipeline_config)
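
# Example usage (a sketch; the path and override below are hypothetical):
#
#   configs = get_configs_from_pipeline_file(
#       "path/to/pipeline.config",
#       config_override="train_config { batch_size: 8 }")
#   model_config = configs["model"]
#   num_classes = get_number_of_classes(model_config)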
|
|
def create_configs_from_pipeline_proto(pipeline_config):
  """Creates a configs dictionary from pipeline_pb2.TrainEvalPipelineConfig.

  Args:
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto object.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
    `train_input_config`, `eval_config`, `eval_input_configs`. Values are
    the corresponding config objects or list of config objects (only for
    eval_input_configs). If present, `eval_input_config` points at the first
    eval input config and `graph_rewriter_config` holds the graph rewriter
    config.
  """
  configs = {}
  configs["model"] = pipeline_config.model
  configs["train_config"] = pipeline_config.train_config
  configs["train_input_config"] = pipeline_config.train_input_reader
  configs["eval_config"] = pipeline_config.eval_config
  configs["eval_input_configs"] = pipeline_config.eval_input_reader
  # Keep a convenience pointer to the first eval input config.
  if configs["eval_input_configs"]:
    configs["eval_input_config"] = configs["eval_input_configs"][0]
  if pipeline_config.HasField("graph_rewriter"):
    configs["graph_rewriter_config"] = pipeline_config.graph_rewriter

  return configs
|
|
def get_graph_rewriter_config_from_file(graph_rewriter_config_file):
  """Parses config for graph rewriter.

  Args:
    graph_rewriter_config_file: file path to the graph rewriter config.

  Returns:
    graph_rewriter_pb2.GraphRewriter proto
  """
  graph_rewriter_config = graph_rewriter_pb2.GraphRewriter()
  with tf.gfile.GFile(graph_rewriter_config_file, "r") as f:
    text_format.Merge(f.read(), graph_rewriter_config)
  return graph_rewriter_config
|
|
def create_pipeline_proto_from_configs(configs):
  """Creates a pipeline_pb2.TrainEvalPipelineConfig from configs dictionary.

  This function performs the inverse operation of
  create_configs_from_pipeline_proto().

  Args:
    configs: Dictionary of configs. See get_configs_from_pipeline_file().

  Returns:
    A fully populated pipeline_pb2.TrainEvalPipelineConfig.
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  pipeline_config.model.CopyFrom(configs["model"])
  pipeline_config.train_config.CopyFrom(configs["train_config"])
  pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"])
  pipeline_config.eval_config.CopyFrom(configs["eval_config"])
  pipeline_config.eval_input_reader.extend(configs["eval_input_configs"])
  if "graph_rewriter_config" in configs:
    pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
  return pipeline_config
|
|
def save_pipeline_config(pipeline_config, directory):
  """Saves a pipeline config text file to disk.

  Args:
    pipeline_config: A pipeline_pb2.TrainEvalPipelineConfig.
    directory: The model directory into which the pipeline config file will be
      saved.
  """
  if not file_io.file_exists(directory):
    file_io.recursive_create_dir(directory)
  pipeline_config_path = os.path.join(directory, "pipeline.config")
  config_text = text_format.MessageToString(pipeline_config)
  with tf.gfile.Open(pipeline_config_path, "wb") as f:
    tf.logging.info("Writing pipeline config file to %s",
                    pipeline_config_path)
    f.write(config_text)
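
# Round-trip example (a sketch; both paths are hypothetical):
#
#   configs = get_configs_from_pipeline_file("path/to/pipeline.config")
#   configs["train_config"].batch_size = 16
#   pipeline_proto = create_pipeline_proto_from_configs(configs)
#   save_pipeline_config(pipeline_proto, "path/to/model_dir")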
|
|
def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
  """Reads training configuration from multiple config files.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
    `train_input_config`, `eval_config`, `eval_input_configs`. Key/value
    pairs are returned only for valid (non-empty) path strings.
  """
  configs = {}
  if model_config_path:
    model_config = model_pb2.DetectionModel()
    with tf.gfile.GFile(model_config_path, "r") as f:
      text_format.Merge(f.read(), model_config)
    configs["model"] = model_config

  if train_config_path:
    train_config = train_pb2.TrainConfig()
    with tf.gfile.GFile(train_config_path, "r") as f:
      text_format.Merge(f.read(), train_config)
    configs["train_config"] = train_config

  if train_input_config_path:
    train_input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(train_input_config_path, "r") as f:
      text_format.Merge(f.read(), train_input_config)
    configs["train_input_config"] = train_input_config

  if eval_config_path:
    eval_config = eval_pb2.EvalConfig()
    with tf.gfile.GFile(eval_config_path, "r") as f:
      text_format.Merge(f.read(), eval_config)
    configs["eval_config"] = eval_config

  if eval_input_config_path:
    eval_input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(eval_input_config_path, "r") as f:
      text_format.Merge(f.read(), eval_input_config)
    configs["eval_input_configs"] = [eval_input_config]

  if graph_rewriter_config_path:
    configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
        graph_rewriter_config_path)

  return configs
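
# Example (a sketch; all paths are hypothetical). Only keys whose paths are
# supplied end up in the returned dictionary:
#
#   configs = get_configs_from_multiple_files(
#       model_config_path="path/to/model.config",
#       train_config_path="path/to/train.config",
#       train_input_config_path="path/to/train_input.config")
#   sorted(configs)  # -> ['model', 'train_config', 'train_input_config']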
|
|
def get_number_of_classes(model_config):
  """Returns the number of classes for a detection model.

  Args:
    model_config: A model_pb2.DetectionModel.

  Returns:
    Number of classes.

  Raises:
    ValueError: If the model type is not recognized.
  """
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "faster_rcnn":
    return model_config.faster_rcnn.num_classes
  if meta_architecture == "ssd":
    return model_config.ssd.num_classes

  raise ValueError("Expected the model to be one of 'faster_rcnn' or 'ssd'.")
|
|
def get_optimizer_type(train_config):
  """Returns the optimizer type for training.

  Args:
    train_config: A train_pb2.TrainConfig.

  Returns:
    The type of the optimizer.
  """
  return train_config.optimizer.WhichOneof("optimizer")
|
|
def get_learning_rate_type(optimizer_config):
  """Returns the learning rate type for training.

  Args:
    optimizer_config: An optimizer_pb2.Optimizer.

  Returns:
    The type of the learning rate.
  """
  return optimizer_config.learning_rate.WhichOneof("learning_rate")
|
|
def _is_generic_key(key):
  """Determines whether the key starts with a generic config dictionary key."""
  for prefix in [
      "graph_rewriter_config",
      "model",
      "train_input_config",
      "train_config",
      "eval_config"]:
    if key.startswith(prefix + "."):
      return True
  return False
|
|
def _check_and_convert_legacy_input_config_key(key):
  """Checks key and converts legacy input config update to specific update.

  Args:
    key: A string indicating the target of the update operation.

  Returns:
    is_valid_input_config_key: A boolean indicating whether the input key is to
      update input config(s).
    key_name: 'eval_input_configs' or 'train_input_config' string if
      is_valid_input_config_key is true. None if is_valid_input_config_key is
      false.
    input_name: always returns None since a legacy input config key never
      specifies the target input config. Keeping this output only to match the
      output form defined for input config update.
    field_name: the field name in input config. `key` itself if
      is_valid_input_config_key is false.
  """
  key_name = None
  input_name = None
  field_name = key
  is_valid_input_config_key = True
  if field_name == "train_shuffle":
    key_name = "train_input_config"
    field_name = "shuffle"
  elif field_name == "eval_shuffle":
    key_name = "eval_input_configs"
    field_name = "shuffle"
  elif field_name == "train_input_path":
    key_name = "train_input_config"
    field_name = "input_path"
  elif field_name == "eval_input_path":
    key_name = "eval_input_configs"
    field_name = "input_path"
  elif field_name == "append_train_input_path":
    key_name = "train_input_config"
    field_name = "input_path"
  elif field_name == "append_eval_input_path":
    key_name = "eval_input_configs"
    field_name = "input_path"
  else:
    is_valid_input_config_key = False

  return is_valid_input_config_key, key_name, input_name, field_name
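
# Sketch of the legacy-key conversion:
#
#   _check_and_convert_legacy_input_config_key("eval_input_path")
#   # -> (True, "eval_input_configs", None, "input_path")
#   _check_and_convert_legacy_input_config_key("learning_rate")
#   # -> (False, None, None, "learning_rate")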
|
|
def check_and_parse_input_config_key(configs, key):
  """Checks key and returns specific fields if key is valid input config update.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: A string indicating the target of the update operation.

  Returns:
    is_valid_input_config_key: A boolean indicating whether the input key is to
      update input config(s).
    key_name: 'eval_input_configs' or 'train_input_config' string if
      is_valid_input_config_key is true. None if is_valid_input_config_key is
      false.
    input_name: the name of the input config to be updated. None if
      is_valid_input_config_key is false.
    field_name: the field name in input config. `key` itself if
      is_valid_input_config_key is false.

  Raises:
    ValueError: when the input key format doesn't match any known formats.
    ValueError: if key_name doesn't match 'eval_input_configs' or
      'train_input_config'.
    ValueError: if input_name doesn't match any name in train or eval input
      configs.
    ValueError: if field_name doesn't match any supported fields.
  """
  key_name = None
  input_name = None
  field_name = None
  fields = key.split(":")
  if len(fields) == 1:
    # Legacy key, e.g. "train_input_path".
    return _check_and_convert_legacy_input_config_key(key)
  elif len(fields) == 3:
    key_name = fields[0]
    input_name = fields[1]
    field_name = fields[2]
  else:
    raise ValueError("Invalid key format when overriding configs.")

  # Checks if key_name is valid.
  if key_name not in ["eval_input_configs", "train_input_config"]:
    raise ValueError("Invalid key_name when overriding input config.")

  # Checks if input_name is valid.
  if isinstance(configs[key_name], input_reader_pb2.InputReader):
    is_valid_input_name = configs[key_name].name == input_name
  else:
    is_valid_input_name = input_name in [
        eval_input_config.name for eval_input_config in configs[key_name]
    ]
  if not is_valid_input_name:
    raise ValueError("Invalid input_name when overriding input config.")

  # Checks if field_name is valid.
  if field_name not in [
      "input_path", "label_map_path", "shuffle", "mask_type",
      "sample_1_of_n_examples"
  ]:
    raise ValueError("Invalid field_name when overriding input config.")

  return True, key_name, input_name, field_name
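
# Example (a sketch; assumes `configs["eval_input_configs"]` holds an input
# reader whose `name` field is "eval_coco"):
#
#   check_and_parse_input_config_key(
#       configs, "eval_input_configs:eval_coco:input_path")
#   # -> (True, "eval_input_configs", "eval_coco", "input_path")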
|
|
def merge_external_params_with_configs(configs, hparams=None, kwargs_dict=None):
  """Updates `configs` dictionary based on supplied parameters.

  This utility is for modifying specific fields in the object detection configs.
  Say that one would like to experiment with different learning rates, momentum
  values, or batch sizes. Rather than creating a new config text file for each
  experiment, one can use a single base config file, and update particular
  values.

  There are two types of field overrides:

  1. Strategy-based overrides, which update multiple relevant configuration
     options. For example, updating `learning_rate` will update both the warmup
     and final learning rates. In this case the key can be one of the following
     formats:
     1. legacy update: a single string that indicates the attribute to be
        updated, e.g. 'label_map_path', 'eval_input_path', 'shuffle'.
        Note that when updating fields (e.g. eval_input_path, eval_shuffle) in
        eval_input_configs, the override will only be applied when
        eval_input_configs has exactly 1 element.
     2. specific update: a colon-separated string that indicates which field in
        which input_config to update. It should have 3 fields:
        - key_name: name of the input config to update, either
          'train_input_config' or 'eval_input_configs'.
        - input_name: a 'name' that can be used to identify elements, especially
          when configs[key_name] is a repeated field.
        - field_name: name of the field that you want to override.
        For example, given a configs dict as below:
          configs = {
              'model': {...},
              'train_config': {...},
              'train_input_config': {...},
              'eval_config': {...},
              'eval_input_configs': [{name: "eval_coco", ...},
                                     {name: "eval_voc", ...}]
          }
        Assume we want to update the input_path of the eval_input_config
        whose name is 'eval_coco'. The `key` would then be:
          'eval_input_configs:eval_coco:input_path'
  2. Generic key/value, which updates a specific parameter based on namespaced
     configuration keys. For example,
     `model.ssd.loss.hard_example_miner.max_negatives_per_positive` will update
     the hard example miner configuration for an SSD model config. Generic
     overrides are automatically detected based on the namespaced keys.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    hparams: A `HParams`.
    kwargs_dict: Extra keyword arguments that are treated the same way as
      attribute/value pairs in `hparams`. Note that hyperparameters with the
      same names will override keyword arguments.

  Returns:
    `configs` dictionary.

  Raises:
    ValueError: when the key string doesn't match any of its allowed formats.
  """
  if kwargs_dict is None:
    kwargs_dict = {}
  if hparams:
    kwargs_dict.update(hparams.values())
  for key, value in kwargs_dict.items():
    tf.logging.info("Maybe overwriting %s: %s", key, value)
    # Empty strings and None are treated as "no override".
    if value == "" or value is None:
      continue
    elif _maybe_update_config_with_key_value(configs, key, value):
      continue
    elif _is_generic_key(key):
      _update_generic(configs, key, value)
    else:
      tf.logging.info("Ignoring config override key: %s", key)
  return configs
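
# Example (a sketch; the values are hypothetical, and the generic namespaced
# key requires the target field to be explicitly set in the base config,
# since _update_generic validates with HasField()):
#
#   configs = merge_external_params_with_configs(
#       configs,
#       kwargs_dict={
#           "learning_rate": 0.01,       # strategy-based override
#           "train_steps": 100000,       # strategy-based override
#           "model.ssd.num_classes": 5,  # generic namespaced override
#       })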
|
|
def _maybe_update_config_with_key_value(configs, key, value):
  """Checks key type and updates `configs` with the key value pair accordingly.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: A string indicating the field(s) to be updated.
    value: Value used to override existing field value.

  Returns:
    A boolean value that indicates whether the override succeeds.

  Raises:
    ValueError: when the key string doesn't match any of the formats above.
  """
  is_valid_input_config_key, key_name, input_name, field_name = (
      check_and_parse_input_config_key(configs, key))
  if is_valid_input_config_key:
    update_input_reader_config(
        configs,
        key_name=key_name,
        input_name=input_name,
        field_name=field_name,
        value=value)
  elif field_name == "learning_rate":
    _update_initial_learning_rate(configs, value)
  elif field_name == "batch_size":
    _update_batch_size(configs, value)
  elif field_name == "momentum_optimizer_value":
    _update_momentum_optimizer_value(configs, value)
  elif field_name == "classification_localization_weight_ratio":
    # The helper fixes localization weight to 1.0 and scales classification
    # (and objectness) weight to the given ratio.
    _update_classification_localization_weight_ratio(configs, value)
  elif field_name == "focal_loss_gamma":
    _update_focal_loss_gamma(configs, value)
  elif field_name == "focal_loss_alpha":
    _update_focal_loss_alpha(configs, value)
  elif field_name == "train_steps":
    _update_train_steps(configs, value)
  elif field_name == "label_map_path":
    _update_label_map_path(configs, value)
  elif field_name == "mask_type":
    _update_mask_type(configs, value)
  elif field_name == "sample_1_of_n_eval_examples":
    _update_all_eval_input_configs(configs, "sample_1_of_n_examples", value)
  elif field_name == "eval_num_epochs":
    _update_all_eval_input_configs(configs, "num_epochs", value)
  elif field_name == "eval_with_moving_averages":
    _update_use_moving_averages(configs, value)
  elif field_name == "retain_original_images_in_eval":
    _update_retain_original_images(configs["eval_config"], value)
  elif field_name == "use_bfloat16":
    _update_use_bfloat16(configs, value)
  else:
    return False
  return True
|
|
def _update_tf_record_input_path(input_config, input_path):
  """Updates input configuration to reflect a new input path.

  The input_config object is updated in place, and hence not returned.

  Args:
    input_config: An input_reader_pb2.InputReader.
    input_path: A path to data or list of paths.

  Raises:
    TypeError: if input reader type is not `tf_record_input_reader`.
  """
  input_reader_type = input_config.WhichOneof("input_reader")
  if input_reader_type == "tf_record_input_reader":
    input_config.tf_record_input_reader.ClearField("input_path")
    if isinstance(input_path, list):
      input_config.tf_record_input_reader.input_path.extend(input_path)
    else:
      input_config.tf_record_input_reader.input_path.append(input_path)
  else:
    raise TypeError("Input reader type must be `tf_record_input_reader`.")
|
|
def update_input_reader_config(configs,
                               key_name=None,
                               input_name=None,
                               field_name=None,
                               value=None,
                               path_updater=_update_tf_record_input_path):
  """Updates specified input reader config field.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key_name: Name of the input config we should update, either
      'train_input_config' or 'eval_input_configs'.
    input_name: String name used to identify the input config to update.
      Should be either None or the value of the 'name' field in one of the
      input reader configs.
    field_name: Field name in input_reader_pb2.InputReader.
    value: Value used to override existing field value.
    path_updater: helper function used to update the input path. Only used when
      field_name is "input_path".

  Raises:
    ValueError: when input field_name is None.
    ValueError: when input_name is None and the number of eval_input_readers
      is not 1.
  """
  if isinstance(configs[key_name], input_reader_pb2.InputReader):
    # Updates a singular input_config object.
    target_input_config = configs[key_name]
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)
  elif input_name is None and len(configs[key_name]) == 1:
    # Updates the first (and only) object of an input_config list.
    target_input_config = configs[key_name][0]
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)
  elif input_name is not None and len(configs[key_name]):
    # Updates the input_config whose name matches input_name.
    update_count = 0
    for input_config in configs[key_name]:
      if input_config.name == input_name:
        if field_name == "input_path":
          # "input_path" lives inside the input reader oneof, so it cannot be
          # set with setattr; route it through the path updater.
          path_updater(input_config=input_config, input_path=value)
        else:
          setattr(input_config, field_name, value)
        update_count = update_count + 1
    if not update_count:
      raise ValueError(
          "Input name {} not found when overriding.".format(input_name))
    elif update_count > 1:
      raise ValueError("Duplicate input name found when overriding.")
  else:
    key_name = "None" if key_name is None else key_name
    input_name = "None" if input_name is None else input_name
    field_name = "None" if field_name is None else field_name
    raise ValueError("Unknown input config overriding: "
                     "key_name:{}, input_name:{}, field_name:{}.".format(
                         key_name, input_name, field_name))
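
# Example (a sketch; assumes an eval input config named "eval_coco" and a
# hypothetical record path):
#
#   update_input_reader_config(
#       configs,
#       key_name="eval_input_configs",
#       input_name="eval_coco",
#       field_name="input_path",
#       value="path/to/eval_coco.record")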
|
|
def _update_initial_learning_rate(configs, learning_rate):
  """Updates `configs` to reflect the new initial learning rate.

  This function updates the initial learning rate. For learning rate schedules,
  all other defined learning rates in the pipeline config are scaled to
  maintain their ratio relative to the initial learning rate.
  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    learning_rate: Initial learning rate for optimizer.

  Raises:
    TypeError: if optimizer type is not supported, or if learning rate type is
      not supported.
  """
  optimizer_type = get_optimizer_type(configs["train_config"])
  if optimizer_type == "rms_prop_optimizer":
    optimizer_config = configs["train_config"].optimizer.rms_prop_optimizer
  elif optimizer_type == "momentum_optimizer":
    optimizer_config = configs["train_config"].optimizer.momentum_optimizer
  elif optimizer_type == "adam_optimizer":
    optimizer_config = configs["train_config"].optimizer.adam_optimizer
  else:
    raise TypeError("Optimizer %s is not supported." % optimizer_type)

  learning_rate_type = get_learning_rate_type(optimizer_config)
  if learning_rate_type == "constant_learning_rate":
    constant_lr = optimizer_config.learning_rate.constant_learning_rate
    constant_lr.learning_rate = learning_rate
  elif learning_rate_type == "exponential_decay_learning_rate":
    exponential_lr = (
        optimizer_config.learning_rate.exponential_decay_learning_rate)
    exponential_lr.initial_learning_rate = learning_rate
  elif learning_rate_type == "manual_step_learning_rate":
    manual_lr = optimizer_config.learning_rate.manual_step_learning_rate
    original_learning_rate = manual_lr.initial_learning_rate
    learning_rate_scaling = float(learning_rate) / original_learning_rate
    manual_lr.initial_learning_rate = learning_rate
    for schedule in manual_lr.schedule:
      schedule.learning_rate *= learning_rate_scaling
  elif learning_rate_type == "cosine_decay_learning_rate":
    cosine_lr = optimizer_config.learning_rate.cosine_decay_learning_rate
    learning_rate_base = cosine_lr.learning_rate_base
    warmup_learning_rate = cosine_lr.warmup_learning_rate
    warmup_scale_factor = warmup_learning_rate / learning_rate_base
    cosine_lr.learning_rate_base = learning_rate
    cosine_lr.warmup_learning_rate = warmup_scale_factor * learning_rate
  else:
    raise TypeError("Learning rate %s is not supported." % learning_rate_type)
|
|
def _update_batch_size(configs, batch_size):
  """Updates `configs` to reflect the new training batch size.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    batch_size: Batch size to use for training (Ideally a power of 2). Inputs
      are rounded, and capped to be 1 or greater.
  """
  configs["train_config"].batch_size = max(1, int(round(batch_size)))
|
|
def _validate_message_has_field(message, field):
  if not message.HasField(field):
    raise ValueError("Expecting message to have field %s" % field)
|
|
def _update_generic(configs, key, value):
  """Update a pipeline configuration parameter based on a generic key/value.

  Args:
    configs: Dictionary of pipeline configuration protos.
    key: A string key, dot-delimited to represent the argument key.
      e.g. "train_config.batch_size"
    value: A value to set the argument to. The type of the value must match the
      type for the protocol buffer. Note that setting the wrong type will
      result in a TypeError.
      e.g. 42

  Raises:
    ValueError: if the message key does not match the existing proto fields.
    TypeError: if the value type doesn't match the protobuf field type.
  """
  fields = key.split(".")
  first_field = fields.pop(0)
  last_field = fields.pop()
  message = configs[first_field]
  for field in fields:
    _validate_message_has_field(message, field)
    message = getattr(message, field)
  _validate_message_has_field(message, last_field)
  setattr(message, last_field, value)
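
# Example (a sketch; assumes an SSD model config in which `num_classes` is
# explicitly set, since the HasField() validation above rejects unset fields):
#
#   _update_generic(configs, "model.ssd.num_classes", 5)
#   # Equivalent to configs["model"].ssd.num_classes = 5, with validation.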
|
|
def _update_momentum_optimizer_value(configs, momentum):
  """Updates `configs` to reflect the new momentum value.

  Momentum is only supported for RMSPropOptimizer and MomentumOptimizer; any
  other optimizer raises a TypeError. The configs dictionary is updated in
  place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    momentum: New momentum value. Values are clipped at 0.0 and 1.0.

  Raises:
    TypeError: If the optimizer type is not `rms_prop_optimizer` or
      `momentum_optimizer`.
  """
  optimizer_type = get_optimizer_type(configs["train_config"])
  if optimizer_type == "rms_prop_optimizer":
    optimizer_config = configs["train_config"].optimizer.rms_prop_optimizer
  elif optimizer_type == "momentum_optimizer":
    optimizer_config = configs["train_config"].optimizer.momentum_optimizer
  else:
    raise TypeError("Optimizer type must be one of `rms_prop_optimizer` or "
                    "`momentum_optimizer`.")

  optimizer_config.momentum_optimizer_value = min(max(0.0, momentum), 1.0)
|
|
def _update_classification_localization_weight_ratio(configs, ratio):
  """Updates the classification/localization weight loss ratio.

  Detection models usually define a loss weight for both classification and
  objectness. This function updates the weights such that the ratio of
  classification weight to localization weight is the ratio provided.
  Arbitrarily, localization weight is set to 1.0.

  Note that in the case of Faster R-CNN, this same ratio is applied to the first
  stage objectness loss weight relative to localization loss weight.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    ratio: Desired ratio of classification (and/or objectness) loss weight to
      localization loss weight.
  """
  meta_architecture = configs["model"].WhichOneof("model")
  if meta_architecture == "faster_rcnn":
    model = configs["model"].faster_rcnn
    model.first_stage_localization_loss_weight = 1.0
    model.first_stage_objectness_loss_weight = ratio
    model.second_stage_localization_loss_weight = 1.0
    model.second_stage_classification_loss_weight = ratio
  if meta_architecture == "ssd":
    model = configs["model"].ssd
    model.loss.localization_weight = 1.0
    model.loss.classification_weight = ratio
|
|
def _get_classification_loss(model_config):
  """Returns the classification loss for a model."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "faster_rcnn":
    model = model_config.faster_rcnn
    classification_loss = model.second_stage_classification_loss
  elif meta_architecture == "ssd":
    model = model_config.ssd
    classification_loss = model.loss.classification_loss
  else:
    raise TypeError("Did not recognize the model architecture.")
  return classification_loss
|
|
def _update_focal_loss_gamma(configs, gamma):
  """Updates the gamma value for a sigmoid focal loss.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    gamma: Exponent term in focal loss.

  Raises:
    TypeError: If the classification loss is not `weighted_sigmoid_focal`.
  """
  classification_loss = _get_classification_loss(configs["model"])
  classification_loss_type = classification_loss.WhichOneof(
      "classification_loss")
  if classification_loss_type != "weighted_sigmoid_focal":
    raise TypeError("Classification loss must be `weighted_sigmoid_focal`.")
  classification_loss.weighted_sigmoid_focal.gamma = gamma
|
|
def _update_focal_loss_alpha(configs, alpha):
  """Updates the alpha value for a sigmoid focal loss.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    alpha: Class weight multiplier for sigmoid loss.

  Raises:
    TypeError: If the classification loss is not `weighted_sigmoid_focal`.
  """
  classification_loss = _get_classification_loss(configs["model"])
  classification_loss_type = classification_loss.WhichOneof(
      "classification_loss")
  if classification_loss_type != "weighted_sigmoid_focal":
    raise TypeError("Classification loss must be `weighted_sigmoid_focal`.")
  classification_loss.weighted_sigmoid_focal.alpha = alpha
|
|
def _update_train_steps(configs, train_steps):
  """Updates `configs` to reflect new number of training steps."""
  configs["train_config"].num_steps = int(train_steps)
|
|
def _update_eval_steps(configs, eval_steps):
  """Updates `configs` to reflect new number of eval steps per evaluation."""
  configs["eval_config"].num_examples = int(eval_steps)
|
|
def _update_all_eval_input_configs(configs, field, value):
  """Updates the content of `field` with `value` for all eval input configs."""
  for eval_input_config in configs["eval_input_configs"]:
    setattr(eval_input_config, field, value)
|
|
def _update_label_map_path(configs, label_map_path):
  """Updates the label map path for both train and eval input readers.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    label_map_path: New path to `StringIntLabelMap` pbtxt file.
  """
  configs["train_input_config"].label_map_path = label_map_path
  _update_all_eval_input_configs(configs, "label_map_path", label_map_path)
|
|
def _update_mask_type(configs, mask_type):
  """Updates the mask type for both train and eval input readers.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    mask_type: A string name representing a value of
      input_reader_pb2.InstanceMaskType.
  """
  configs["train_input_config"].mask_type = mask_type
  _update_all_eval_input_configs(configs, "mask_type", mask_type)
|
|
def _update_use_moving_averages(configs, use_moving_averages):
  """Updates the eval config option to use or not use moving averages.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    use_moving_averages: Boolean indicating whether moving average variables
      should be loaded during evaluation.
  """
  configs["eval_config"].use_moving_averages = use_moving_averages
|
|
def _update_retain_original_images(eval_config, retain_original_images):
  """Updates eval config with option to retain original images.

  The eval_config object is updated in place, and hence not returned.

  Args:
    eval_config: An eval_pb2.EvalConfig.
    retain_original_images: Boolean indicating whether to retain original
      images in eval mode.
  """
  eval_config.retain_original_images = retain_original_images
|
|
def _update_use_bfloat16(configs, use_bfloat16):
  """Updates `configs` to reflect the new setup on whether to use bfloat16.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    use_bfloat16: A bool, indicating whether to use bfloat16 for training.
  """
  configs["train_config"].use_bfloat16 = use_bfloat16