# Source: DR-App/object_detection/matchers/argmax_matcher_test.py
# Uploaded by pat229988 ("Upload 653 files"), commit 9a393e2, 10.7 kB.
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.matchers.argmax_matcher."""
import numpy as np
import tensorflow as tf
from object_detection.matchers import argmax_matcher
from object_detection.utils import test_case
class ArgMaxMatcherTest(test_case.TestCase):
  """Tests for `argmax_matcher.ArgMaxMatcher`.

  Each test builds a float32 similarity matrix of shape [num_rows, num_cols]
  and checks which columns end up matched (and to which row), which end up
  unmatched, and -- implicitly -- which are neither.
  """

  def test_return_correct_matches_with_default_thresholds(self):
    """With matched_threshold=None every column is matched to its argmax row."""

    def graph_fn(similarity_matrix):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=None)
      match = matcher.match(similarity_matrix)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1., 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    expected_matched_rows = np.array([2, 0, 1, 0, 1])
    (res_matched_cols, res_unmatched_cols,
     res_match_results) = self.execute(graph_fn, [similarity])

    self.assertAllEqual(res_match_results[res_matched_cols],
                        expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], [0, 1, 2, 3, 4])
    # Every column is matched, so no column may be reported as unmatched.
    # (np.all would only catch the case where *all* columns are unmatched.)
    self.assertFalse(np.any(res_unmatched_cols))

  def test_return_correct_matches_with_empty_rows(self):
    """A similarity matrix with zero rows leaves every column unmatched."""

    def graph_fn(similarity_matrix):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=None)
      match = matcher.match(similarity_matrix)
      return match.unmatched_column_indicator()

    similarity = 0.2 * np.ones([0, 5], dtype=np.float32)
    res_unmatched_cols = self.execute(graph_fn, [similarity])
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0], np.arange(5))

  def test_return_correct_matches_with_matched_threshold(self):
    """Columns whose max similarity is below matched_threshold are unmatched."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3, 4])
    expected_matched_rows = np.array([2, 0, 1])
    expected_unmatched_cols = np.array([1, 2])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_with_matched_and_unmatched_threshold(self):
    """A column whose max lies between the two thresholds is neither state."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3, 4])
    expected_matched_rows = np.array([2, 0, 1])
    expected_unmatched_cols = np.array([1])  # col 2 has too high maximum val

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_negatives_lower_than_unmatched_false(self):
    """With negatives_lower_than_unmatched=False the middle band is unmatched.

    Col 2 (max between the thresholds) is expected to be unmatched, while
    col 1 (max below unmatched_threshold) is expected to be neither.
    """

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(
          matched_threshold=3.,
          unmatched_threshold=2.,
          negatives_lower_than_unmatched=False)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3, 4])
    expected_matched_rows = np.array([2, 0, 1])
    expected_unmatched_cols = np.array([2])  # col 1 has too low maximum val

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_unmatched_row_not_using_force_match(self):
    """Without force matching, a uniformly low row (row 1) gets no column."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [-1, 0, -2, -2, -1],
                           [3, 0, -1, 2, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3])
    expected_matched_rows = np.array([2, 0])
    expected_unmatched_cols = np.array([1, 2, 4])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_unmatched_row_while_using_force_match(self):
    """With force_match_for_each_row, row 1 is forced onto its argmax col 1."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.,
                                             force_match_for_each_row=True)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [-1, 0, -2, -2, -1],
                           [3, 0, -1, 2, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 1, 3])
    expected_matched_rows = np.array([2, 1, 0])
    # Cols 2 and 4 have a max val (1) below unmatched_threshold.
    expected_unmatched_cols = np.array([2, 4])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_using_force_match_padded_groundtruth(self):
    """Rows flagged False in valid_rows (rows 2 and 4) never receive a match."""

    def graph_fn(similarity, valid_rows):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.,
                                             force_match_for_each_row=True)
      match = matcher.match(similarity, valid_rows)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [-1, 0, -2, -2, -1],
                           [0, 0, 0, 0, 0],
                           [3, 0, -1, 2, 0],
                           [0, 0, 0, 0, 0]], dtype=np.float32)
    valid_rows = np.array([True, True, False, True, False])
    expected_matched_cols = np.array([0, 1, 3])
    expected_matched_rows = np.array([3, 1, 0])
    # Cols 2 and 4 have a max val (1) below unmatched_threshold.
    expected_unmatched_cols = np.array([2, 4])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity, valid_rows])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_valid_arguments_corner_case(self):
    """Equal matched and unmatched thresholds are accepted (no exception)."""
    argmax_matcher.ArgMaxMatcher(matched_threshold=1,
                                 unmatched_threshold=1)

  def test_invalid_arguments_corner_case_negatives_lower_than_thres_false(self):
    """Equal thresholds are rejected when negatives_lower_than_unmatched=False."""
    with self.assertRaises(ValueError):
      argmax_matcher.ArgMaxMatcher(matched_threshold=1,
                                   unmatched_threshold=1,
                                   negatives_lower_than_unmatched=False)

  def test_invalid_arguments_no_matched_threshold(self):
    """An unmatched_threshold without a matched_threshold raises ValueError."""
    with self.assertRaises(ValueError):
      argmax_matcher.ArgMaxMatcher(matched_threshold=None,
                                   unmatched_threshold=4)

  def test_invalid_arguments_unmatched_thres_larger_than_matched_thres(self):
    """An unmatched_threshold above matched_threshold raises ValueError."""
    with self.assertRaises(ValueError):
      argmax_matcher.ArgMaxMatcher(matched_threshold=1,
                                   unmatched_threshold=2)
# Run the tests in this module through TensorFlow's test runner.
if __name__ == '__main__':
  tf.test.main()