Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""A function to build an object detection matcher from configuration.""" | |
from object_detection.matchers import argmax_matcher | |
from object_detection.protos import matcher_pb2 | |
from object_detection.utils import tf_version | |
if tf_version.is_tf1(): | |
from object_detection.matchers import bipartite_matcher # pylint: disable=g-import-not-at-top | |
def build(matcher_config):
  """Builds a matcher object based on the matcher config.

  Args:
    matcher_config: A matcher.proto object containing the config for the desired
      Matcher.

  Returns:
    Matcher based on the config: an `argmax_matcher.ArgMaxMatcher` when the
    `argmax_matcher` field is set, or a
    `bipartite_matcher.GreedyBipartiteMatcher` when the `bipartite_matcher`
    field is set.

  Raises:
    ValueError: If `matcher_config` is not a `matcher_pb2.Matcher`; if a
      `bipartite_matcher` is requested while running under TF 2.X; or if no
      matcher is set in the `matcher_oneof` field (empty matcher proto) or the
      set field is not one handled here.
  """
  if not isinstance(matcher_config, matcher_pb2.Matcher):
    raise ValueError('matcher_config not of type matcher_pb2.Matcher.')

  # Read the oneof selector once rather than re-querying it for each branch.
  matcher_type = matcher_config.WhichOneof('matcher_oneof')

  if matcher_type == 'argmax_matcher':
    matcher = matcher_config.argmax_matcher
    # When ignore_thresholds is set, both thresholds are passed as None;
    # otherwise the configured values are forwarded.
    matched_threshold = unmatched_threshold = None
    if not matcher.ignore_thresholds:
      matched_threshold = matcher.matched_threshold
      unmatched_threshold = matcher.unmatched_threshold
    return argmax_matcher.ArgMaxMatcher(
        matched_threshold=matched_threshold,
        unmatched_threshold=unmatched_threshold,
        negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched,
        force_match_for_each_row=matcher.force_match_for_each_row,
        use_matmul_gather=matcher.use_matmul_gather)

  if matcher_type == 'bipartite_matcher':
    # The bipartite_matcher module is imported at module level only when
    # tf_version.is_tf1() holds, so TF 2.X is rejected before the name is used.
    if tf_version.is_tf2():
      raise ValueError('bipartite_matcher is not supported in TF 2.X')
    matcher = matcher_config.bipartite_matcher
    return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)

  # WhichOneof returned None (nothing set) or a field not handled above.
  raise ValueError('Empty matcher.')