# File: RASP-Synthesis/tracr/craft/bases_test.py
# Commit d4d39d0 (Vladimir Mikulik): "Add typing_extensions to list of deps."
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bases."""
from absl.testing import absltest
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
class VectorInBasisTest(tests_common.VectorFnTestCase):
  """Checks construction, equality and arithmetic of `bases.VectorInBasis`."""

  def test_shape_mismatch_raises_value_error(self):
    # Four-wide magnitudes paired with a two-direction basis must be rejected,
    # both for a single vector and for a batch of vectors.
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    expected_error = (r"^.*Last dimension of magnitudes must be the same as number of "
                      r"basis directions.*$")
    mismatched_magnitudes = (
        np.array([1, 2, 3, 4]),
        np.array([[0, 1, 2, 3], [1, 2, 3, 4]]),
    )
    for magnitudes in mismatched_magnitudes:
      with self.assertRaisesRegex(ValueError, expected_error):
        bases.VectorInBasis(space.basis, magnitudes)

  def test_equal(self):
    abcd = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    efgh = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"])

    def make(space, magnitudes):
      return bases.VectorInBasis(space.basis, np.array(magnitudes))

    equal_pairs = (
        # Identical single vectors.
        (make(abcd, [1, 2, 3, 4]), make(abcd, [1, 2, 3, 4])),
        # Identical batches of vectors.
        (make(abcd, [[0, 1, 2, 3], [1, 2, 3, 4]]),
         make(abcd, [[0, 1, 2, 3], [1, 2, 3, 4]])),
    )
    unequal_pairs = (
        # Same basis, different magnitudes.
        (make(abcd, [1, 2, 3, 4]), make(abcd, [1, 1, 1, 1])),
        # A single vector versus a batch of vectors.
        (make(abcd, [1, 2, 3, 4]), make(abcd, [[1, 2, 3, 4], [1, 1, 1, 1]])),
        # Same magnitudes over different basis directions.
        (make(abcd, [1, 2, 3, 4]), make(efgh, [1, 2, 3, 4])),
    )

    # Every pair is compared in both operand orders.
    for lhs, rhs in equal_pairs:
      self.assertEqual(lhs, rhs)
      self.assertEqual(rhs, lhs)
    for lhs, rhs in unequal_pairs:
      self.assertNotEqual(lhs, rhs)
      self.assertNotEqual(rhs, lhs)

  def test_dunders(self):
    space = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])

    def vec(*magnitudes):
      return bases.VectorInBasis(space.basis, np.array(magnitudes))

    v = vec(0, 1, 2)
    threes = vec(3, 3, 3)
    fives = vec(5, 5, 5)

    # Scalar multiplication from either side, with int and float scalars.
    scaled = vec(0, 5, 10)
    for product in (5 * v, v * 5, 5.0 * v, v * 5.0):
      self.assertEqual(product, scaled)

    # Division by a scalar, and multiplication by a fractional scalar.
    halved = vec(0, 0.5, 1)
    for quotient in (v / 2, v / 2.0, 1 / 2 * v):
      self.assertEqual(quotient, halved)

    # Addition, checked in both operand orders.
    shifted = vec(3, 4, 5)
    self.assertEqual(v + threes, shifted)
    self.assertEqual(threes + v, shifted)

    # Subtraction and unary negation.
    self.assertEqual(v - fives, vec(-5, -4, -3))
    self.assertEqual(-v, vec(0, -1, -2))
class ProjectionTest(tests_common.VectorFnTestCase):
  """Checks combining vector spaces and projecting vectors between them."""

  def test_direct_sum_produces_expected_result(self):
    left = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    right = bases.VectorSpaceWithBasis.from_names(["d", "c"])
    # The expected space lists the directions of `left` followed by `right`.
    expected = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"])
    self.assertEqual(bases.direct_sum(left, right), expected)

  def test_join_vector_spaces_produces_expected_result(self):
    expected = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    name_pairs = (
        (["a", "b"], ["d", "c"]),  # Inputs with no direction in common.
        (["a", "b"], ["b", "d", "c"]),  # Inputs sharing direction "b".
    )
    for left_names, right_names in name_pairs:
      left = bases.VectorSpaceWithBasis.from_names(left_names)
      right = bases.VectorSpaceWithBasis.from_names(right_names)
      self.assertEqual(bases.join_vector_spaces(left, right), expected)

  def test_compare_vectors_with_differently_ordered_basis_vectors(self):
    basis1 = [bases.BasisDirection(name) for name in ["a", "b", "c", "d"]]
    basis2 = [bases.BasisDirection(name) for name in ["b", "d", "a", "c"]]
    space1 = bases.VectorSpaceWithBasis(basis1)
    space2 = bases.VectorSpaceWithBasis(basis2)

    # Both vectors encode a=1, b=2, c=3, d=4, listed in different orders.
    v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4]))
    v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3]))

    self.assertEqual(v1, v2)
    for space in (space1, space2):
      self.assertEqual(v1 - v2, space.null_vector())
    self.assertEqual(v1 + v2, 2 * v2)
    for vector in (v1, v2):
      for space in (space1, space2):
        self.assertIn(vector, space)

  def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self):
    basis1 = [bases.BasisDirection(name) for name in ["a", "b", "c", "d"]]
    basis2 = [bases.BasisDirection(name) for name in ["b", "d", "a", "c"]]
    space1 = bases.VectorSpaceWithBasis(basis1)
    space2 = bases.VectorSpaceWithBasis(basis2)

    # Row-wise, both batches encode the same vectors in different orders.
    v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]]))
    v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]]))
    zeros = bases.VectorInBasis.stack(
        [space1.null_vector(), space2.null_vector()])

    self.assertEqual(v1, v2)
    self.assertEqual(v1 - v2, zeros)
    self.assertEqual(v1 + v2, 2 * v2)
    for vector in (v1, v2):
      for space in (space1, space2):
        self.assertIn(vector, space)

  def test_projection_to_larger_space(self):
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    large = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    a_small, b_small = small.basis_vectors()
    a_large, b_large, _, _ = large.basis_vectors()
    # Each basis vector is expected to land on its namesake in the larger space.
    for vector, expected in ((a_small, a_large), (b_small, b_large)):
      self.assertEqual(vector.project(large), expected)

  def test_projection_to_smaller_space(self):
    large = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a_large, b_large, c_large, d_large = large.basis_vectors()
    a_small, b_small = small.basis_vectors()
    # Directions present in the target space map onto their namesakes...
    self.assertEqual(a_large.project(small), a_small)
    self.assertEqual(b_large.project(small), b_small)
    # ...while directions absent from it are expected to project to zero.
    for dropped in (c_large, d_large):
      self.assertEqual(dropped.project(small), small.null_vector())
if __name__ == "__main__":
  # Run this module's tests through absl's test runner.
  absltest.main()