# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bases."""

from absl.testing import absltest
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common


class VectorInBasisTest(tests_common.VectorFnTestCase):
  """Tests construction, equality and arithmetic of `bases.VectorInBasis`."""

  def test_shape_mismatch_raises_value_error(self):
    """Magnitudes whose last dimension != number of basis directions fail."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    regex = (r"^.*Last dimension of magnitudes must be the same as number of "
             r"basis directions.*$")

    # A single 1-D vector with more components than basis directions.
    with self.assertRaisesRegex(ValueError, regex):
      bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))

    # A stacked 2-D array whose last dimension is also too long.
    with self.assertRaisesRegex(ValueError, regex):
      bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))

  def test_equal(self):
    """Checks ==/!= in both argument orders across values, shapes, bases."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])

    # Identical 1-D vectors are equal.
    v1 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v2 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    self.assertEqual(v1, v2)
    self.assertEqual(v2, v1)

    # Identical stacked (2-D) vectors are equal.
    v3 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
    v4 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
    self.assertEqual(v3, v4)
    self.assertEqual(v4, v3)

    # Same basis, different magnitudes: not equal.
    v5 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v6 = bases.VectorInBasis(vs1.basis, np.array([1, 1, 1, 1]))
    self.assertNotEqual(v5, v6)
    self.assertNotEqual(v6, v5)

    # A 1-D vector differs from a 2-D stack, even when one row matches.
    v7 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v8 = bases.VectorInBasis(vs1.basis, np.array([[1, 2, 3, 4], [1, 1, 1, 1]]))
    self.assertNotEqual(v7, v8)
    self.assertNotEqual(v8, v7)

    # Same magnitudes expressed in a disjoint basis: not equal.
    vs2 = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"])
    v9 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v10 = bases.VectorInBasis(vs2.basis, np.array([1, 2, 3, 4]))
    self.assertNotEqual(v9, v10)
    self.assertNotEqual(v10, v9)

  def test_dunders(self):
    """Checks scalar * and /, vector + and -, and unary negation."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2]))
    three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3]))
    five = bases.VectorInBasis(vs1.basis, np.array([5, 5, 5]))

    # Scalar multiplication from either side, with int and float scalars.
    v_times_5 = bases.VectorInBasis(vs1.basis, np.array([0, 5, 10]))
    self.assertEqual(5 * v, v_times_5)
    self.assertEqual(v * 5, v_times_5)
    self.assertEqual(5.0 * v, v_times_5)
    self.assertEqual(v * 5.0, v_times_5)

    # Division by a scalar, and the equivalent left multiplication by 1/2.
    v_by_2 = bases.VectorInBasis(vs1.basis, np.array([0, 0.5, 1]))
    self.assertEqual(v / 2, v_by_2)
    self.assertEqual(v / 2.0, v_by_2)
    self.assertEqual(1 / 2 * v, v_by_2)

    # Vector addition in both orders.
    v_plus_3 = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5]))
    self.assertEqual(v + three, v_plus_3)
    self.assertEqual(three + v, v_plus_3)

    # Vector subtraction.
    v_minus_5 = bases.VectorInBasis(vs1.basis, np.array([-5, -4, -3]))
    self.assertEqual(v - five, v_minus_5)

    # Unary negation.
    minus_v = bases.VectorInBasis(vs1.basis, np.array([0, -1, -2]))
    self.assertEqual(-v, minus_v)


class ProjectionTest(tests_common.VectorFnTestCase):
  """Tests combining vector spaces and projecting vectors between them."""

  def test_direct_sum_produces_expected_result(self):
    """direct_sum concatenates the bases in argument order."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"])
    self.assertEqual(bases.direct_sum(vs1, vs2), vs3)

  def test_join_vector_spaces_produces_expected_result(self):
    """join_vector_spaces merges bases, keeping shared directions once."""
    # Disjoint bases; the expected result lists the basis in sorted order.
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)

    # Overlapping bases: "b" appears in both inputs but once in the result.
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["b", "d", "c"])
    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)

  def test_compare_vectors_with_differently_ordered_basis_vectors(self):
    """Vectors over permuted bases compare and combine component-wise."""
    basis1 = ["a", "b", "c", "d"]
    basis1 = [bases.BasisDirection(x) for x in basis1]
    basis2 = ["b", "d", "a", "c"]
    basis2 = [bases.BasisDirection(x) for x in basis2]

    vs1 = bases.VectorSpaceWithBasis(basis1)
    vs2 = bases.VectorSpaceWithBasis(basis2)

    # Both vectors have a=1, b=2, c=3, d=4; only the storage order differs.
    v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4]))
    v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3]))

    self.assertEqual(v1, v2)
    self.assertEqual(v1 - v2, vs1.null_vector())
    self.assertEqual(v1 - v2, vs2.null_vector())
    self.assertEqual(v1 + v2, 2 * v2)
    self.assertIn(v1, vs1)
    self.assertIn(v1, vs2)
    self.assertIn(v2, vs1)
    self.assertIn(v2, vs2)

  def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self):
    """Same as above, but for stacked (2-D) vectors."""
    basis1 = ["a", "b", "c", "d"]
    basis1 = [bases.BasisDirection(x) for x in basis1]
    basis2 = ["b", "d", "a", "c"]
    basis2 = [bases.BasisDirection(x) for x in basis2]

    vs1 = bases.VectorSpaceWithBasis(basis1)
    vs2 = bases.VectorSpaceWithBasis(basis2)

    # Each row of v2 is the matching row of v1 rewritten in basis2 order.
    v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]]))
    v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]]))

    # Stack two null vectors to match the two rows of v1 - v2.
    null_vec = bases.VectorInBasis.stack(
        [vs1.null_vector(), vs2.null_vector()])

    self.assertEqual(v1, v2)
    self.assertEqual(v1 - v2, null_vec)
    self.assertEqual(v1 + v2, 2 * v2)
    self.assertIn(v1, vs1)
    self.assertIn(v1, vs2)
    self.assertIn(v2, vs1)
    self.assertIn(v2, vs2)

  def test_projection_to_larger_space(self):
    """Projecting into a superset space keeps each basis vector intact."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])

    a1, b1 = vs1.basis_vectors()
    a2, b2, _, _ = vs2.basis_vectors()

    self.assertEqual(a1.project(vs2), a2)
    self.assertEqual(b1.project(vs2), b2)

  def test_projection_to_smaller_space(self):
    """Directions missing from the target space project to its null vector."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"])

    a1, b1, c1, d1 = vs1.basis_vectors()
    a2, b2 = vs2.basis_vectors()

    self.assertEqual(a1.project(vs2), a2)
    self.assertEqual(b1.project(vs2), b2)
    # "c" and "d" are not in vs2, so they project to zero there.
    self.assertEqual(c1.project(vs2), vs2.null_vector())
    self.assertEqual(d1.project(vs2), vs2.null_vector())


if __name__ == "__main__":
  absltest.main()