# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions on vector spaces."""

import abc
import dataclasses
from typing import Callable, Sequence

import numpy as np
from tracr.craft import bases

VectorSpaceWithBasis = bases.VectorSpaceWithBasis
VectorInBasis = bases.VectorInBasis
BasisDirection = bases.BasisDirection


class VectorFunction(abc.ABC):
  """A function that acts on vectors."""

  input_space: VectorSpaceWithBasis
  output_space: VectorSpaceWithBasis

  @abc.abstractmethod
  def __call__(self, x: VectorInBasis) -> VectorInBasis:
    """Evaluates the function."""


class Linear(VectorFunction):
  """A linear function.

  The function is stored as a matrix of shape [input_dims, output_dims] and is
  applied to a vector by right-multiplication: `x.magnitudes @ matrix`.
  """

  def __init__(
      self,
      input_space: VectorSpaceWithBasis,
      output_space: VectorSpaceWithBasis,
      matrix: np.ndarray,
  ):
    """Initialises.

    Args:
      input_space: The input vector space.
      output_space: The output vector space.
      matrix: a [input, output] matrix acting in a (sorted) basis.

    Raises:
      AssertionError: (when assertions are enabled) if `matrix.shape` is not
        (input_space.num_dims, output_space.num_dims).
    """
    self.input_space = input_space
    self.output_space = output_space
    self.matrix = matrix
    # `Linear` is a plain class rather than a dataclass, so nothing calls
    # `__post_init__` automatically; invoke it so the shape check really runs.
    self.__post_init__()

  def __post_init__(self) -> None:
    """Typechecks that `matrix` has shape [input_dims, output_dims]."""
    # Rows index input directions and columns index output directions; this
    # matches `from_action`, which fills row i with the image of the i-th
    # input basis direction, and `__call__`, which computes `x @ matrix`.
    input_size, output_size = self.matrix.shape
    assert input_size == self.input_space.num_dims, (
        f"matrix has {input_size} rows but input_space has "
        f"{self.input_space.num_dims} dims.")
    assert output_size == self.output_space.num_dims, (
        f"matrix has {output_size} columns but output_space has "
        f"{self.output_space.num_dims} dims.")

  def __call__(self, x: VectorInBasis) -> VectorInBasis:
    """Applies the linear map to `x`.

    Args:
      x: A vector in `self.input_space`.

    Returns:
      The image of `x`, expressed in the sorted basis of `self.output_space`.

    Raises:
      TypeError: if `x` is not in `self.input_space`.
    """
    if x not in self.input_space:
      raise TypeError(f"x={x} not in self.input_space={self.input_space}.")
    return VectorInBasis(
        basis_directions=sorted(self.output_space.basis),
        magnitudes=x.magnitudes @ self.matrix,
    )

  @classmethod
  def from_action(
      cls,
      input_space: VectorSpaceWithBasis,
      output_space: VectorSpaceWithBasis,
      action: Callable[[BasisDirection], VectorInBasis],
  ) -> "Linear":
    """from_action(i, o)(action) creates a Linear.

    Args:
      input_space: The input vector space.
      output_space: The output vector space.
      action: Maps each basis direction of `input_space` to its image, which
        must be a vector in `output_space`.

    Returns:
      A `Linear` whose i-th matrix row holds the magnitudes of the image of
      the i-th basis direction of `input_space`.

    Raises:
      TypeError: if some image is not in `output_space`.
    """
    matrix = np.zeros((input_space.num_dims, output_space.num_dims))
    for i, direction in enumerate(input_space.basis):
      out_vector = action(direction)
      if out_vector not in output_space:
        raise TypeError(f"image of {direction} from input_space={input_space} "
                        f"is not in output_space={output_space}")
      matrix[i, :] = out_vector.magnitudes
    return Linear(input_space, output_space, matrix)

  @classmethod
  def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear":
    """Combines multiple parallel linear functions into a single one.

    The result acts on the join of all input spaces and maps into the join of
    all output spaces. A joint input direction is sent to the sum of its
    images under every function whose input space contains it, each image
    projected into the joint output space.

    Args:
      fns: The linear functions to combine.

    Returns:
      A single `Linear` over the joined spaces.
    """
    joint_input_space = bases.join_vector_spaces(
        *[fn.input_space for fn in fns])
    joint_output_space = bases.join_vector_spaces(
        *[fn.output_space for fn in fns])

    def action(x: bases.BasisDirection) -> bases.VectorInBasis:
      out = joint_output_space.null_vector()
      for fn in fns:
        # Functions whose input space does not contain `x` contribute zero.
        if x in fn.input_space:
          x_vec = fn.input_space.vector_from_basis_direction(x)
          out += fn(x_vec).project(joint_output_space)
      return out

    return cls.from_action(joint_input_space, joint_output_space, action)


def project(
    from_space: VectorSpaceWithBasis,
    to_space: VectorSpaceWithBasis,
) -> Linear:
  """Creates a projection.

  Basis directions of `from_space` that also lie in `to_space` are mapped to
  themselves; all other directions are mapped to the zero vector of
  `to_space`.

  Args:
    from_space: The input vector space.
    to_space: The output vector space.

  Returns:
    The projection as a `Linear`.
  """

  def action(direction: bases.BasisDirection) -> VectorInBasis:
    if direction in to_space:
      return to_space.vector_from_basis_direction(direction)
    else:
      return to_space.null_vector()

  return Linear.from_action(from_space, to_space, action=action)


@dataclasses.dataclass
class ScalarBilinear:
  """A scalar-valued bilinear operator.

  Evaluates `x.magnitudes.T @ matrix @ y.magnitudes` for `x` in `left_space`
  and `y` in `right_space`.
  """

  left_space: VectorSpaceWithBasis
  right_space: VectorSpaceWithBasis
  # Shape [left_space.num_dims, right_space.num_dims].
  matrix: np.ndarray

  def __post_init__(self):
    """Ensure matrix acts in sorted bases and typecheck sizes."""
    left_size, right_size = self.matrix.shape
    assert left_size == self.left_space.num_dims
    assert right_size == self.right_space.num_dims

  def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
    """Describes the action of the operator on vectors.

    Args:
      x: A vector in `self.left_space`.
      y: A vector in `self.right_space`.

    Returns:
      The scalar `x^T M y` as a Python number.

    Raises:
      TypeError: if `x` or `y` is not in its respective space.
    """
    if x not in self.left_space:
      raise TypeError(f"x={x} not in self.left_space={self.left_space}.")
    if y not in self.right_space:
      raise TypeError(f"y={y} not in self.right_space={self.right_space}.")
    return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()

  @classmethod
  def from_action(
      cls,
      left_space: VectorSpaceWithBasis,
      right_space: VectorSpaceWithBasis,
      action: Callable[[BasisDirection, BasisDirection], float],
  ) -> "ScalarBilinear":
    """from_action(l, r)(action) creates a ScalarBilinear.

    Args:
      left_space: The left input vector space.
      right_space: The right input vector space.
      action: Returns the operator's value on a pair of basis directions.

    Returns:
      A `ScalarBilinear` whose entry [i, j] is `action` applied to the i-th
      basis direction of `left_space` and the j-th of `right_space`.
    """
    matrix = np.zeros((left_space.num_dims, right_space.num_dims))
    for i, left_direction in enumerate(left_space.basis):
      for j, right_direction in enumerate(right_space.basis):
        matrix[i, j] = action(left_direction, right_direction)
    return ScalarBilinear(left_space, right_space, matrix)