Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Functions on vector spaces.""" | |
import abc | |
import dataclasses | |
from typing import Callable, Sequence | |
import numpy as np | |
from tracr.craft import bases | |
# Short module-level aliases for the `bases` types used throughout this module.
VectorSpaceWithBasis, VectorInBasis, BasisDirection = (
    bases.VectorSpaceWithBasis,
    bases.VectorInBasis,
    bases.BasisDirection,
)
class VectorFunction(abc.ABC):
    """A function that acts on vectors.

    Concrete subclasses are expected to set `input_space` and `output_space`
    and must implement `__call__`.
    """

    input_space: VectorSpaceWithBasis
    output_space: VectorSpaceWithBasis

    # Abstract so that a subclass which forgets to implement `__call__` fails
    # at instantiation instead of silently returning None when called.
    @abc.abstractmethod
    def __call__(self, x: VectorInBasis) -> VectorInBasis:
        """Evaluates the function.

        Args:
          x: a vector in `self.input_space`.

        Returns:
          The image of `x` in `self.output_space`.
        """
class Linear(VectorFunction):
    """A linear function between two vector spaces.

    The function is stored as a dense matrix of shape
    [input_space.num_dims, output_space.num_dims] and is applied to the row
    vector of input magnitudes, i.e. `y = x @ matrix`.
    """

    def __init__(
        self,
        input_space: VectorSpaceWithBasis,
        output_space: VectorSpaceWithBasis,
        matrix: np.ndarray,
    ):
        """Initialises.

        Args:
          input_space: The input vector space.
          output_space: The output vector space.
          matrix: a [input, output] matrix acting in a (sorted) basis.

        Raises:
          ValueError: if `matrix` does not have shape
            (input_space.num_dims, output_space.num_dims).
        """
        self.input_space = input_space
        self.output_space = output_space
        self.matrix = matrix
        # Linear is a plain class, not a dataclass, so __post_init__ is never
        # invoked automatically; call it explicitly so the check actually runs.
        self.__post_init__()

    def __post_init__(self) -> None:
        """Checks that `matrix` has shape [input dims, output dims].

        Raises:
          ValueError: if the shape of `matrix` does not match the spaces.
        """
        # The matrix is [input, output] (`__call__` computes `x @ matrix` and
        # `from_action` allocates (input dims, output dims)), so axis 0 must
        # match the input space and axis 1 the output space.
        expected = (self.input_space.num_dims, self.output_space.num_dims)
        actual = tuple(self.matrix.shape)
        if actual != expected:
            raise ValueError(
                f"matrix has shape {actual}, expected {expected} "
                "(input_space.num_dims, output_space.num_dims).")

    def __call__(self, x: VectorInBasis) -> VectorInBasis:
        """Applies the function to `x`.

        Args:
          x: a vector in `self.input_space`.

        Returns:
          The image of `x`, expressed in the sorted basis of
          `self.output_space`.

        Raises:
          TypeError: if `x` is not in `self.input_space`.
        """
        if x not in self.input_space:
            raise TypeError(f"x={x} not in self.input_space={self.input_space}.")
        return VectorInBasis(
            basis_directions=sorted(self.output_space.basis),
            magnitudes=x.magnitudes @ self.matrix,
        )

    @classmethod
    def from_action(
        cls,
        input_space: VectorSpaceWithBasis,
        output_space: VectorSpaceWithBasis,
        action: Callable[[BasisDirection], VectorInBasis],
    ) -> "Linear":
        """Creates a Linear from its action on the input basis directions.

        Args:
          input_space: The input vector space.
          output_space: The output vector space.
          action: maps each basis direction of `input_space` to its image,
            which must be a vector in `output_space`.

        Returns:
          A Linear whose i-th matrix row holds the magnitudes of the image of
          the i-th basis direction of `input_space`.

        Raises:
          TypeError: if the image of some direction is not in `output_space`.
        """
        matrix = np.zeros((input_space.num_dims, output_space.num_dims))
        for i, direction in enumerate(input_space.basis):
            out_vector = action(direction)
            if out_vector not in output_space:
                raise TypeError(f"image of {direction} from input_space={input_space} "
                                f"is not in output_space={output_space}")
            matrix[i, :] = out_vector.magnitudes
        return cls(input_space, output_space, matrix)

    @classmethod
    def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear":
        """Combines multiple parallel linear functions into a single one.

        The result acts on the join of all input spaces and maps into the join
        of all output spaces. Each input basis direction is passed to every
        function whose input space contains it, and those functions' outputs,
        projected onto the joint output space, are summed.

        Args:
          fns: the linear functions to combine.

        Returns:
          A single Linear over the joint input and output spaces.
        """
        joint_input_space = bases.join_vector_spaces(
            *[fn.input_space for fn in fns])
        joint_output_space = bases.join_vector_spaces(
            *[fn.output_space for fn in fns])

        def action(x: bases.BasisDirection) -> bases.VectorInBasis:
            out = joint_output_space.null_vector()
            for fn in fns:
                if x in fn.input_space:
                    x_vec = fn.input_space.vector_from_basis_direction(x)
                    # Embed fn's output in the joint output space so that the
                    # outputs of different functions can be added together.
                    out += fn(x_vec).project(joint_output_space)
            return out

        return cls.from_action(joint_input_space, joint_output_space, action)
def project(
    from_space: VectorSpaceWithBasis,
    to_space: VectorSpaceWithBasis,
) -> Linear:
    """Builds the linear projection from `from_space` onto `to_space`.

    Every basis direction of `from_space` that also belongs to `to_space` is
    sent to the matching basis vector of `to_space`; every other direction is
    sent to `to_space.null_vector()`.
    """

    def keep_if_shared(direction: bases.BasisDirection) -> VectorInBasis:
        if direction not in to_space:
            return to_space.null_vector()
        return to_space.vector_from_basis_direction(direction)

    return Linear.from_action(from_space, to_space, action=keep_if_shared)
@dataclasses.dataclass
class ScalarBilinear:
    """A scalar-valued bilinear operator.

    Computes `x^T @ matrix @ y` for `x` in `left_space` and `y` in
    `right_space`, where `matrix` has shape
    [left_space.num_dims, right_space.num_dims].
    """

    left_space: VectorSpaceWithBasis
    right_space: VectorSpaceWithBasis
    matrix: np.ndarray

    def __post_init__(self):
        """Checks that `matrix` has shape [left dims, right dims].

        Raises:
          ValueError: if the shape of `matrix` does not match the spaces.
        """
        # Raise explicitly rather than `assert`, which is stripped under -O.
        expected = (self.left_space.num_dims, self.right_space.num_dims)
        actual = tuple(self.matrix.shape)
        if actual != expected:
            raise ValueError(
                f"matrix has shape {actual}, expected {expected} "
                "(left_space.num_dims, right_space.num_dims).")

    def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
        """Evaluates the operator on a pair of vectors.

        Args:
          x: a vector in `self.left_space`.
          y: a vector in `self.right_space`.

        Returns:
          The scalar `x^T @ matrix @ y` as a Python number.

        Raises:
          TypeError: if `x` is not in `left_space` or `y` is not in
            `right_space`.
        """
        if x not in self.left_space:
            raise TypeError(f"x={x} not in self.left_space={self.left_space}.")
        if y not in self.right_space:
            raise TypeError(f"y={y} not in self.right_space={self.right_space}.")
        return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()

    @classmethod
    def from_action(
        cls,
        left_space: VectorSpaceWithBasis,
        right_space: VectorSpaceWithBasis,
        action: Callable[[BasisDirection, BasisDirection], float],
    ) -> "ScalarBilinear":
        """Creates a ScalarBilinear from its action on pairs of basis directions.

        Args:
          left_space: The left vector space.
          right_space: The right vector space.
          action: returns the operator's value on a (left, right) pair of
            basis directions.

        Returns:
          A ScalarBilinear whose matrix entry [i, j] is
          `action(left_space.basis[i], right_space.basis[j])`.
        """
        matrix = np.zeros((left_space.num_dims, right_space.num_dims))
        for i, left_direction in enumerate(left_space.basis):
            for j, right_direction in enumerate(right_space.basis):
                matrix[i, j] = action(left_direction, right_direction)
        return cls(left_space, right_space, matrix)