Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Tests for bases.""" | |
from absl.testing import absltest | |
import numpy as np | |
from tracr.craft import bases | |
from tracr.craft import tests_common | |
class VectorInBasisTest(tests_common.VectorFnTestCase):
  """Tests for constructing, comparing and combining `bases.VectorInBasis`."""

  def test_shape_mismatch_raises_value_error(self):
    """Checks that mismatched magnitudes are rejected with a ValueError.

    The basis has two directions while the magnitudes have a last dimension
    of four, both for a 1-D array and for a 2-D array.
    """
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    regex = (r"^.*Last dimension of magnitudes must be the same as number of "
             r"basis directions.*$")

    # 1-D magnitudes: four values for a two-direction basis.
    with self.assertRaisesRegex(ValueError, regex):
      bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))

    # 2-D magnitudes: each row has four values for a two-direction basis.
    with self.assertRaisesRegex(ValueError, regex):
      bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))

  def test_equal(self):
    """Checks `==` / `!=` between vectors, in both operand orders."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])

    # Same basis, same 1-D magnitudes: equal.
    v1 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v2 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    self.assertEqual(v1, v2)
    self.assertEqual(v2, v1)

    # Same basis, same 2-D magnitudes: equal.
    v3 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
    v4 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
    self.assertEqual(v3, v4)
    self.assertEqual(v4, v3)

    # Same basis, different magnitude values: not equal.
    v5 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v6 = bases.VectorInBasis(vs1.basis, np.array([1, 1, 1, 1]))
    self.assertNotEqual(v5, v6)
    self.assertNotEqual(v6, v5)

    # Same basis, 1-D magnitudes versus 2-D magnitudes: not equal.
    v7 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v8 = bases.VectorInBasis(vs1.basis, np.array([[1, 2, 3, 4], [1, 1, 1, 1]]))
    self.assertNotEqual(v7, v8)
    self.assertNotEqual(v8, v7)

    # Same magnitudes, differently named basis directions: not equal.
    vs2 = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"])
    v9 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
    v10 = bases.VectorInBasis(vs2.basis, np.array([1, 2, 3, 4]))
    self.assertNotEqual(v9, v10)
    self.assertNotEqual(v10, v9)

  def test_dunders(self):
    """Checks the arithmetic operators `*`, `/`, `+`, binary `-`, unary `-`."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2]))

    three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3]))
    five = bases.VectorInBasis(vs1.basis, np.array([5, 5, 5]))

    # Scalar multiplication from either side, with int and float scalars.
    v_times_5 = bases.VectorInBasis(vs1.basis, np.array([0, 5, 10]))
    self.assertEqual(5 * v, v_times_5)
    self.assertEqual(v * 5, v_times_5)
    self.assertEqual(5.0 * v, v_times_5)
    self.assertEqual(v * 5.0, v_times_5)

    # Division by a scalar, and the equivalent multiplication by 1 / 2.
    v_by_2 = bases.VectorInBasis(vs1.basis, np.array([0, 0.5, 1]))
    self.assertEqual(v / 2, v_by_2)
    self.assertEqual(v / 2.0, v_by_2)
    self.assertEqual(1 / 2 * v, v_by_2)

    # Vector addition in both operand orders.
    v_plus_3 = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5]))
    self.assertEqual(v + three, v_plus_3)
    self.assertEqual(three + v, v_plus_3)

    # Vector subtraction.
    v_minus_5 = bases.VectorInBasis(vs1.basis, np.array([-5, -4, -3]))
    self.assertEqual(v - five, v_minus_5)

    # Unary negation.
    minus_v = bases.VectorInBasis(vs1.basis, np.array([0, -1, -2]))
    self.assertEqual(-v, minus_v)
class ProjectionTest(tests_common.VectorFnTestCase):
  """Tests for combining vector spaces, basis ordering and projection."""

  def test_direct_sum_produces_expected_result(self):
    """Checks `direct_sum` of spaces (a, b) and (d, c) against (a, b, d, c)."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"])
    self.assertEqual(bases.direct_sum(vs1, vs2), vs3)

  def test_join_vector_spaces_produces_expected_result(self):
    """Checks `join_vector_spaces` for disjoint and for overlapping inputs.

    In both cases the expected result is the space built from the names
    a, b, c, d.
    """
    # Disjoint inputs: (a, b) joined with (d, c).
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)

    # Overlapping inputs: "b" appears in both spaces but only once in the
    # expected result.
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["b", "d", "c"])
    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)

  def test_compare_vectors_with_differently_ordered_basis_vectors(self):
    """Checks that basis ordering does not affect comparison or arithmetic.

    `v1` and `v2` assign the same magnitude to every named direction
    (a=1, b=2, c=3, d=4) but list the directions in a different order.
    """
    basis1 = ["a", "b", "c", "d"]
    basis1 = [bases.BasisDirection(x) for x in basis1]
    basis2 = ["b", "d", "a", "c"]
    basis2 = [bases.BasisDirection(x) for x in basis2]

    vs1 = bases.VectorSpaceWithBasis(basis1)
    vs2 = bases.VectorSpaceWithBasis(basis2)

    v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4]))
    v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3]))

    self.assertEqual(v1, v2)
    # The difference must equal the null vector of either space.
    self.assertEqual(v1 - v2, vs1.null_vector())
    self.assertEqual(v1 - v2, vs2.null_vector())
    self.assertEqual(v1 + v2, 2 * v2)
    # Each vector must be reported as a member of both spaces.
    self.assertIn(v1, vs1)
    self.assertIn(v1, vs2)
    self.assertIn(v2, vs1)
    self.assertIn(v2, vs2)

  def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self):
    """Same as the single-vector ordering test, but with 2-D magnitudes.

    Each row of `v1` and `v2` assigns the same magnitude to every named
    direction while the directions are listed in a different order.
    """
    basis1 = ["a", "b", "c", "d"]
    basis1 = [bases.BasisDirection(x) for x in basis1]
    basis2 = ["b", "d", "a", "c"]
    basis2 = [bases.BasisDirection(x) for x in basis2]

    vs1 = bases.VectorSpaceWithBasis(basis1)
    vs2 = bases.VectorSpaceWithBasis(basis2)

    v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]]))
    v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]]))

    # Expected difference: the null vectors of the two spaces stacked.
    null_vec = bases.VectorInBasis.stack([vs1.null_vector(), vs2.null_vector()])

    self.assertEqual(v1, v2)
    self.assertEqual(v1 - v2, null_vec)
    self.assertEqual(v1 + v2, 2 * v2)
    # Each vector array must be reported as a member of both spaces.
    self.assertIn(v1, vs1)
    self.assertIn(v1, vs2)
    self.assertIn(v2, vs1)
    self.assertIn(v2, vs2)

  def test_projection_to_larger_space(self):
    """Checks that projecting a, b from (a, b) into (a, b, c, d) gives a, b."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    a1, b1 = vs1.basis_vectors()
    a2, b2, _, _ = vs2.basis_vectors()

    self.assertEqual(a1.project(vs2), a2)
    self.assertEqual(b1.project(vs2), b2)

  def test_projection_to_smaller_space(self):
    """Checks projection from (a, b, c, d) down to (a, b).

    a and b must map to the corresponding basis vectors of the smaller
    space; c and d must map to its null vector.
    """
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a1, b1, c1, d1 = vs1.basis_vectors()
    a2, b2 = vs2.basis_vectors()

    self.assertEqual(a1.project(vs2), a2)
    self.assertEqual(b1.project(vs2), b2)
    self.assertEqual(c1.project(vs2), vs2.null_vector())
    self.assertEqual(d1.project(vs2), vs2.null_vector())
if __name__ == "__main__":
  # Hand control to the absl test runner when executed as a script.
  absltest.main()