Spaces:
Runtime error
Runtime error
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """A function to build an object detection matcher from configuration.""" | |
| from object_detection.matchers import argmax_matcher | |
| from object_detection.protos import matcher_pb2 | |
| from object_detection.utils import tf_version | |
| if tf_version.is_tf1(): | |
| from object_detection.matchers import bipartite_matcher # pylint: disable=g-import-not-at-top | |
def build(matcher_config):
  """Builds a matcher object based on the matcher config.

  Args:
    matcher_config: A matcher.proto object containing the config for the desired
      Matcher.

  Returns:
    Matcher based on the config: an `argmax_matcher.ArgMaxMatcher` when the
    `matcher_oneof` field selects `argmax_matcher`, or a
    `bipartite_matcher.GreedyBipartiteMatcher` when it selects
    `bipartite_matcher`.

  Raises:
    ValueError: If `matcher_config` is not a `matcher_pb2.Matcher`; if the
      config selects `bipartite_matcher` while not running under TF 1.X; or
      on an empty matcher proto (no `matcher_oneof` field set).
  """
  if not isinstance(matcher_config, matcher_pb2.Matcher):
    raise ValueError('matcher_config not of type matcher_pb2.Matcher.')

  # Read the oneof once; both branches below dispatch on the same value.
  matcher_type = matcher_config.WhichOneof('matcher_oneof')

  if matcher_type == 'argmax_matcher':
    matcher = matcher_config.argmax_matcher
    # When ignore_thresholds is set, both thresholds are passed through as None.
    matched_threshold = unmatched_threshold = None
    if not matcher.ignore_thresholds:
      matched_threshold = matcher.matched_threshold
      unmatched_threshold = matcher.unmatched_threshold
    return argmax_matcher.ArgMaxMatcher(
        matched_threshold=matched_threshold,
        unmatched_threshold=unmatched_threshold,
        negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched,
        force_match_for_each_row=matcher.force_match_for_each_row,
        use_matmul_gather=matcher.use_matmul_gather)

  if matcher_type == 'bipartite_matcher':
    # The `bipartite_matcher` module is imported at load time only when
    # tf_version.is_tf1() is true, so guard on exactly that condition.
    # Guarding on is_tf2() alone would raise NameError (instead of this
    # ValueError) if is_tf1() and is_tf2() were ever both false.
    if not tf_version.is_tf1():
      raise ValueError('bipartite_matcher is not supported in TF 2.X')
    matcher = matcher_config.bipartite_matcher
    return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)

  raise ValueError('Empty matcher.')