# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inferring the vector spaces taken on by certain operations."""

import dataclasses
import itertools
from typing import Set

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.rasp import rasp
from tracr.utils import errors

Node = nodes.Node


@dataclasses.dataclass
class InferBasesOutput:
  # NOTE(review): not constructed anywhere in this file (`infer_bases` below
  # annotates its graph in place and returns None); presumably kept for
  # callers elsewhere -- confirm before removing.
  graph: nx.DiGraph


def infer_bases(
    graph: nx.DiGraph,
    sink: Node,
    vocab: Set[rasp.Value],
    max_seq_len: int,
) -> None:
  """Infers in-place the possible output values and vector bases of the SOps.

  Walks the reversed graph from `sink` in DFS post-order. For every visited
  node whose expression is a `rasp.SOp`, the set of values it can output is
  stored under `nodes.VALUE_SET` and the basis of its output space under
  `nodes.OUTPUT_BASIS`. Nodes holding any other expression are skipped.

  Args:
    graph: Graph whose nodes carry an expression under `nodes.EXPR`. The value
      sets of an S-Op's inputs are looked up via the input's `label`, so node
      ids are expected to be those labels. Annotated in place.
    sink: Node whose `nodes.ID` entry is the root of the traversal.
    vocab: Value set assigned to `rasp.tokens`.
    max_seq_len: Maximum sequence length; bounds the value sets of
      `rasp.indices`, `rasp.SelectorWidth` and numerical aggregates.

  Raises:
    NotImplementedError: If a numerical `rasp.Aggregate` averages an S-Op whose
      values, cast to `int`, are not exactly `{0, 1}`.
    ValueError: If an S-Op is of an unsupported kind, or is neither categorical
      nor numerical.
  """

  def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]:
    """Computes value set using already-computed predecessor value sets."""
    if sop is rasp.tokens:
      # Returned by reference, so the node shares the caller's set object.
      return vocab

    if sop is rasp.indices:
      return set(range(max_seq_len))

    if isinstance(sop, rasp.SelectorWidth):
      # Widths range over 0..max_seq_len inclusive.
      return set(range(max_seq_len + 1))

    if isinstance(sop, rasp.Full):
      return {sop.fill}

    if isinstance(sop, rasp.Map):
      inner_value_set = graph.nodes[sop.inner.label][nodes.VALUE_SET]
      # Build the wrapper once rather than once per input value.
      f_ignore_error = errors.ignoring_arithmetic_errors(sop.f)
      out = set()
      for x in inner_value_set:
        res = f_ignore_error(x)
        # The wrapper's None results are dropped from the value set.
        if res is not None:
          out.add(res)
      return out

    if isinstance(sop, rasp.SequenceMap):
      f_ignore_error = errors.ignoring_arithmetic_errors(sop.f)
      fst_value_set = graph.nodes[sop.fst.label][nodes.VALUE_SET]
      snd_value_set = graph.nodes[sop.snd.label][nodes.VALUE_SET]
      out = set()
      # Every pairing of the two input value sets is considered.
      for left, right in itertools.product(fst_value_set, snd_value_set):
        res = f_ignore_error(left, right)
        if res is not None:
          out.add(res)
      return out

    if isinstance(sop, rasp.Aggregate):
      if rasp.is_categorical(sop):
        # Simply pass on the value set of the underlying S-Op.
        return graph.nodes[sop.sop.label][nodes.VALUE_SET]
      if rasp.is_numerical(sop):
        # TODO(b/255936408): This doesn't work if we average arbitrary values.
        # But most examples only average binary variables.
        sop_value_set = graph.nodes[sop.sop.label][nodes.VALUE_SET]
        if {int(x) for x in sop_value_set} != {0, 1}:
          raise NotImplementedError(
              "Attention patterns can currently only "
              "average binary variables. Not:", sop_value_set)

        # Each input value divided by every length from 1 to max_seq_len.
        return {
            value / length
            for value in sop_value_set
            for length in range(1, max_seq_len + 1)
        }
      # An Aggregate that is neither categorical nor numerical falls through
      # to the generic error below.

    raise ValueError(f"Unsupported S-Op: {sop}")

  # Post-order on the reversed graph yields a node only after the nodes it
  # depends on, so input value sets exist by the time they are read.
  for node_id in nx.dfs_postorder_nodes(graph.reverse(), sink[nodes.ID]):
    expr = graph.nodes[node_id][nodes.EXPR]

    if not isinstance(expr, rasp.SOp):
      # Only S-Ops have output vector spaces.
      continue

    value_set = compute_value_set(expr)
    graph.nodes[node_id][nodes.VALUE_SET] = value_set

    if rasp.is_categorical(expr):
      # The categorical space is built from the S-Op's label and value set.
      out_space = bases.VectorSpaceWithBasis.from_values(expr.label, value_set)
    elif rasp.is_numerical(expr):
      # The numerical space is built from the label alone.
      out_space = bases.VectorSpaceWithBasis.from_names([expr.label])
    else:
      raise ValueError(f"Unsupported S-Op type: {expr.type}")

    graph.nodes[node_id][nodes.OUTPUT_BASIS] = out_space.basis