RASP-Synthesis / tracr /craft /vectorspace_fns.py
NeelNanda's picture
Made compatible with Python 3.8
c46567d
raw
history blame
No virus
5.46 kB
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions on vector spaces."""
import abc
import dataclasses
from typing import Callable, Sequence
import numpy as np
from tracr.craft import bases
# Module-local aliases for the basis types from `tracr.craft.bases`, so the
# signatures below can refer to them without the `bases.` prefix.
VectorSpaceWithBasis, VectorInBasis, BasisDirection = (
    bases.VectorSpaceWithBasis,
    bases.VectorInBasis,
    bases.BasisDirection,
)
class VectorFunction(abc.ABC):
    """Abstract interface for a function mapping a vector to a vector.

    Concrete subclasses are expected to set `input_space` and `output_space`
    and to implement `__call__`.
    """

    # Space the function's argument is expected to lie in.
    input_space: VectorSpaceWithBasis
    # Space the function's result lies in.
    output_space: VectorSpaceWithBasis

    @abc.abstractmethod
    def __call__(self, x: VectorInBasis) -> VectorInBasis:
        """Applies the function to `x` and returns the resulting vector."""
class Linear(VectorFunction):
    """A linear function between two vector spaces.

    The function is stored as a matrix laid out as [input, output]: row i is
    the image of the i-th input basis direction (see `from_action`), and
    `__call__` computes the row-vector product `x.magnitudes @ matrix`.
    """

    def __init__(
        self,
        input_space: VectorSpaceWithBasis,
        output_space: VectorSpaceWithBasis,
        matrix: np.ndarray,
    ):
        """Initialises.

        Args:
          input_space: The input vector space.
          output_space: The output vector space.
          matrix: a [input, output] matrix acting in a (sorted) basis.

        Raises:
          ValueError: if `matrix` is not 2-dimensional or its shape is not
            (input_space.num_dims, output_space.num_dims).
        """
        self.input_space = input_space
        self.output_space = output_space
        self.matrix = matrix
        # Linear is not a dataclass, so nothing invokes `__post_init__`
        # automatically; call it explicitly so the shape check actually runs.
        self.__post_init__()

    def __post_init__(self) -> None:
        """Typechecks the matrix shape against the input and output spaces.

        Raises:
          ValueError: if the matrix shape does not match the spaces.
        """
        shape = np.shape(self.matrix)
        if len(shape) != 2:
            raise ValueError(f"matrix must be 2-dimensional, got shape {shape}.")
        # The layout is [input, output] (as built by `from_action`), so the
        # first axis is the input dimension.
        input_size, output_size = shape
        if input_size != self.input_space.num_dims:
            raise ValueError(
                f"matrix has {input_size} rows but input_space has "
                f"{self.input_space.num_dims} dimensions.")
        if output_size != self.output_space.num_dims:
            raise ValueError(
                f"matrix has {output_size} columns but output_space has "
                f"{self.output_space.num_dims} dimensions.")

    def __call__(self, x: VectorInBasis) -> VectorInBasis:
        """Applies the function to `x`.

        Args:
          x: A vector in `self.input_space`.

        Returns:
          The image of `x`, expressed in the sorted basis of
          `self.output_space`.

        Raises:
          TypeError: if `x` is not in `self.input_space`.
        """
        if x not in self.input_space:
            raise TypeError(f"x={x} not in self.input_space={self.input_space}.")
        return VectorInBasis(
            basis_directions=sorted(self.output_space.basis),
            magnitudes=x.magnitudes @ self.matrix,
        )

    @classmethod
    def from_action(
        cls,
        input_space: VectorSpaceWithBasis,
        output_space: VectorSpaceWithBasis,
        action: Callable[[BasisDirection], VectorInBasis],
    ) -> "Linear":
        """from_action(i, o)(action) creates a Linear.

        Args:
          input_space: The input vector space.
          output_space: The output vector space.
          action: Maps each basis direction of `input_space` to its image,
            which must lie in `output_space`.

        Returns:
          An instance of `cls` whose i-th matrix row holds the magnitudes of
          the image of the i-th direction of `input_space.basis`.

        Raises:
          TypeError: if some image is not in `output_space`.
        """
        matrix = np.zeros((input_space.num_dims, output_space.num_dims))
        for i, direction in enumerate(input_space.basis):
            out_vector = action(direction)
            if out_vector not in output_space:
                raise TypeError(f"image of {direction} from input_space={input_space} "
                                f"is not in output_space={output_space}")
            matrix[i, :] = out_vector.magnitudes
        # Use `cls` so subclasses get instances of their own type.
        return cls(input_space, output_space, matrix)

    @classmethod
    def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear":
        """Combines multiple parallel linear functions into a single one.

        The result acts on the join of all input spaces and maps into the
        join of all output spaces. A basis direction that lies in the input
        space of several functions is mapped to the sum of their images.
        """
        joint_input_space = bases.join_vector_spaces(
            *[fn.input_space for fn in fns])
        joint_output_space = bases.join_vector_spaces(
            *[fn.output_space for fn in fns])

        def action(x: bases.BasisDirection) -> bases.VectorInBasis:
            out = joint_output_space.null_vector()
            for fn in fns:
                if x in fn.input_space:
                    x_vec = fn.input_space.vector_from_basis_direction(x)
                    # Lift each partial image into the joint output space so
                    # the contributions can be summed.
                    out += fn(x_vec).project(joint_output_space)
            return out

        return cls.from_action(joint_input_space, joint_output_space, action)
def project(
    from_space: VectorSpaceWithBasis,
    to_space: VectorSpaceWithBasis,
) -> Linear:
    """Builds the linear projection from `from_space` onto `to_space`.

    A basis direction of `from_space` that also belongs to `to_space` is
    mapped to `to_space.vector_from_basis_direction(direction)`. Every other
    direction is mapped to `to_space.null_vector()`.
    """

    def _image_of(direction: bases.BasisDirection) -> VectorInBasis:
        # Directions outside the target space are dropped by the projection.
        if direction not in to_space:
            return to_space.null_vector()
        return to_space.vector_from_basis_direction(direction)

    return Linear.from_action(from_space, to_space, action=_image_of)
@dataclasses.dataclass
class ScalarBilinear:
    """A scalar-valued bilinear operator.

    Stored as a matrix of shape [left_space.num_dims, right_space.num_dims];
    applying it to (x, y) computes `x.magnitudes.T @ matrix @ y.magnitudes`.
    """

    left_space: VectorSpaceWithBasis
    right_space: VectorSpaceWithBasis
    matrix: np.ndarray

    def __post_init__(self):
        """Typechecks the matrix shape against the left and right spaces.

        Raises:
          ValueError: if the matrix shape does not match the spaces.
        """
        # Unpacking fails with ValueError if the matrix is not 2-dimensional.
        left_size, right_size = np.shape(self.matrix)
        if left_size != self.left_space.num_dims:
            raise ValueError(
                f"matrix has {left_size} rows but left_space has "
                f"{self.left_space.num_dims} dimensions.")
        if right_size != self.right_space.num_dims:
            raise ValueError(
                f"matrix has {right_size} columns but right_space has "
                f"{self.right_space.num_dims} dimensions.")

    def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
        """Evaluates the operator on the pair (x, y).

        Raises:
          TypeError: if `x` is not in `left_space` or `y` is not in
            `right_space`.
        """
        if x not in self.left_space:
            raise TypeError(f"x={x} not in self.left_space={self.left_space}.")
        if y not in self.right_space:
            raise TypeError(f"y={y} not in self.right_space={self.right_space}.")
        return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()

    @classmethod
    def from_action(
        cls,
        left_space: VectorSpaceWithBasis,
        right_space: VectorSpaceWithBasis,
        action: Callable[[BasisDirection, BasisDirection], float],
    ) -> "ScalarBilinear":
        """from_action(l, r)(action) creates a ScalarBilinear.

        Args:
          left_space: The left argument's vector space.
          right_space: The right argument's vector space.
          action: Returns the operator's value on each pair of basis
            directions (left, right).

        Returns:
          An instance of `cls` whose entry (i, j) is `action` evaluated on
          the i-th left and j-th right basis directions.
        """
        matrix = np.zeros((left_space.num_dims, right_space.num_dims))
        for i, left_direction in enumerate(left_space.basis):
            for j, right_direction in enumerate(right_space.basis):
                matrix[i, j] = action(left_direction, right_direction)
        # Use `cls` so subclasses get instances of their own type.
        return cls(left_space, right_space, matrix)