# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for vectorspace_fns."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft import vectorspace_fns as vs_fns


class LinearTest(tests_common.VectorFnTestCase):
  """Tests for `vs_fns.Linear` and `vs_fns.Linear.combine_in_parallel`."""

  def test_identity_from_matrix(self):
    # An identity matrix must map every basis vector to itself.
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear(vs, vs, np.eye(3))
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_identity_from_action(self):
    # Building the map from an action that turns each basis direction back
    # into its own vector should also yield the identity.
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction)
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_nonidentity(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]]))
    # Row i of the matrix is the image of the i-th input basis direction.
    self.assertEqual(
        f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7])))
    self.assertEqual(
        f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1])))

  def test_different_vector_spaces(self):
    # An identity matrix between two distinct spaces maps the i-th input
    # basis vector onto the i-th output basis vector.
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    a, b = vs1.basis_vectors()
    c, d = vs2.basis_vectors()
    f = vs_fns.Linear(vs1, vs2, np.eye(2))
    self.assertEqual(f(a), c)
    self.assertEqual(f(b), d)

  def test_combining_linear_functions_with_different_input(self):
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    vs = bases.direct_sum(vs1, vs2)
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    c = vs.vector_from_basis_direction(bases.BasisDirection("c"))
    d = vs.vector_from_basis_direction(bases.BasisDirection("d"))
    # f1 swaps a and b; f2 keeps c and sends d to zero.
    f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]]))
    # The combined map acts on the direct sum, applying each part to its own
    # subspace.
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0])))
    self.assertEqual(
        f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0])))
    self.assertEqual(
        f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0])))

  def test_combining_linear_functions_with_same_input(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]]))
    # When the maps share a space, combining them sums their outputs.
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0])))
    self.assertEqual(f3(a), f1(a) + f2(a))
    self.assertEqual(f3(b), f1(b) + f2(b))


class ProjectionTest(tests_common.VectorFnTestCase):
  """Tests for `vs_fns.project`."""

  def test_projection_to_larger_space(self):
    # Projecting into a superset space keeps shared directions unchanged.
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    a1, b1 = vs1.basis_vectors()
    a2, b2, _, _ = vs2.basis_vectors()
    f = vs_fns.project(vs1, vs2)
    self.assertEqual(f(a1), a2)
    self.assertEqual(f(b1), b2)

  def test_projection_to_smaller_space(self):
    # Directions missing from the target space are sent to the null vector.
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a1, b1, c1, d1 = vs1.basis_vectors()
    a2, b2 = vs2.basis_vectors()
    f = vs_fns.project(vs1, vs2)
    self.assertEqual(f(a1), a2)
    self.assertEqual(f(b1), b2)
    self.assertEqual(f(c1), vs2.null_vector())
    self.assertEqual(f(d1), vs2.null_vector())


class ScalarBilinearTest(parameterized.TestCase):
  """Tests for `vs_fns.ScalarBilinear`."""

  def test_identity_matrix(self):
    # With an identity matrix the form is 1 on equal basis vectors, else 0.
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = vs.basis_vectors()
    f = vs_fns.ScalarBilinear(vs, vs, np.eye(2))
    self.assertEqual(f(a, a), 1)
    self.assertEqual(f(a, b), 0)
    self.assertEqual(f(b, a), 0)
    self.assertEqual(f(b, b), 1)

  def test_identity_from_action(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = vs.basis_vectors()
    f = vs_fns.ScalarBilinear.from_action(vs, vs, lambda x, y: int(x == y))
    self.assertEqual(f(a, a), 1)
    self.assertEqual(f(a, b), 0)
    self.assertEqual(f(b, a), 0)
    self.assertEqual(f(b, b), 1)

  def test_non_identity(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = vs.basis_vectors()
    # The action depends only on the first argument's name, so the form is
    # 1 whenever the first argument is "a", whatever the second one is.
    f = vs_fns.ScalarBilinear.from_action(vs, vs,
                                          lambda x, y: int(x.name == "a"))
    self.assertEqual(f(a, a), 1)
    self.assertEqual(f(a, b), 1)
    self.assertEqual(f(b, a), 0)
    self.assertEqual(f(b, b), 0)


if __name__ == "__main__":
  absltest.main()