# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add craft model blocks to graph of RASPExpr."""

from typing import Any, Callable, Optional

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.craft.chamber import numerical_mlp
from tracr.craft.chamber import selector_width
from tracr.rasp import rasp


def _transform_fun_to_basis_fun(
    fun: Callable[..., Any],
    output_direction_name: Optional[str] = None) -> Callable[..., Any]:
  """Transforms a function acting on values into one acting on directions.

  Args:
    fun: Function that is called with the `.value` of every direction passed
      to the returned function.
    output_direction_name: If truthy, the result of `fun` is wrapped in a
      `bases.BasisDirection` with this name. Otherwise the raw result of `fun`
      is returned.

  Returns:
    A function that takes basis directions as positional arguments.
  """

  def bases_fun(*args):
    values = [d.value for d in args]
    result = fun(*values)
    if output_direction_name:
      return bases.BasisDirection(output_direction_name, result)
    return result

  return bases_fun


def _make_attn_basis_fn(
    predicate: Callable[..., Any]
) -> Callable[[bases.BasisDirection, bases.BasisDirection], bool]:
  """Builds the attention function for a selector predicate.

  The predicate is bound when this factory is called, so the returned function
  does not depend on any variable that is rebound later by the caller.

  Args:
    predicate: Selector predicate, called with the values as (key, query).

  Returns:
    A function taking (query, key) basis directions and returning the value of
    the predicate on their values.
  """
  selector_basis_fn = _transform_fun_to_basis_fun(predicate)

  def attn_basis_fn(query: bases.BasisDirection,
                    key: bases.BasisDirection) -> bool:
    # Argument order is different in craft / transformers than RASP selectors
    return selector_basis_fn(key, query)

  return attn_basis_fn


def _check_is_plain_select(sel_expr) -> None:
  """Raises a TypeError if `sel_expr` is not a `rasp.Select`."""
  if not isinstance(sel_expr, rasp.Select):
    raise TypeError("Compiling composite Selectors is not supported. "
                    f"Got a {sel_expr}.")


def _check_selector_expression(expr, graph):
  """Check graph structure and encodings for an aggregate or selector width.

  Args:
    expr: Expression with a `selector` attribute whose `keys` and `queries`
      are inspected.
    graph: Graph in which the selector must precede `expr`, and the keys and
      queries must precede the selector.

  Raises:
    AssertionError: If one of the expected predecessor edges is missing.
    ValueError: If the selector keys or queries are not categorical.
  """
  sel_expr = expr.selector

  # Check graph structure
  assert sel_expr.label in graph.predecessors(expr.label)
  assert sel_expr.keys.label in graph.predecessors(sel_expr.label)
  assert sel_expr.queries.label in graph.predecessors(sel_expr.label)

  if (not rasp.is_categorical(sel_expr.queries) or
      not rasp.is_categorical(sel_expr.keys)):
    raise ValueError("Selector keys and queries must be categorical.")


def add_craft_components_to_rasp_graph(
    graph: nx.DiGraph,
    bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"),
    one_dir: bases.BasisDirection = bases.BasisDirection("one"),
    causal: bool = False,
    mlp_exactness: float = 100,
) -> None:
  """Translates expressions to craft blocks and attaches them to the graph.

  Sets the `MODEL_BLOCK` attribute for all nodes in `graph` whose expression is
  a `rasp.SOp`. Nodes holding any other expression are skipped. The block is
  `None` for `rasp.tokens` and `rasp.indices`.

  Args:
    graph: RASP graph with  `VALUE_SET` but not `MODEL_BLOCK` attributes.
    bos_dir: Basis direction representing beginning of sequence (bos) token.
    one_dir: Auxiliary basis direction that must contain 1.
    causal: If True, marks attention blocks as causal. This is only forwarded
      to the blocks built for `rasp.Aggregate`; blocks built for
      `rasp.SelectorWidth` are always created with `causal=False`.
    mlp_exactness: Controls the approximation of the MLP layers.

  Raises:
    ValueError: On invalid input (if `MODEL_BLOCK` is set already, or
      `VALUE_SET` is not set already), or if an aggregate or selector has
      unsupported encodings or default value.
    TypeError: If an aggregate or selector width uses a selector that is not a
      `rasp.Select`.
    NotImplementedError: If the graph contains an unsupported expression.
  """
  one_space = bases.VectorSpaceWithBasis([one_dir])

  for node_id, node in graph.nodes.items():
    expr = node[nodes.EXPR]

    if not isinstance(expr, rasp.SOp):
      continue

    if nodes.MODEL_BLOCK in node and node[nodes.MODEL_BLOCK]:
      raise ValueError("Input graph cannot have model blocks set already.")
    if nodes.VALUE_SET not in node:
      raise ValueError(
          "Craft components can only be added after basis inference.")

    if expr is rasp.tokens or expr is rasp.indices:
      block = None

    elif isinstance(expr, rasp.Map):
      inner_expr, inner_node = expr.inner, graph.nodes[expr.inner.label]
      assert inner_expr.label in graph.predecessors(node_id)
      input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS])
      output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      # One MLP constructor per (input encoding, output encoding) pair.
      if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr):
        basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
        block = categorical_mlp.map_categorical_mlp(
            input_space=input_space,
            output_space=output_space,
            operation=basis_fun)
      elif rasp.is_categorical(inner_expr) and rasp.is_numerical(expr):
        block = categorical_mlp.map_categorical_to_numerical_mlp(
            input_space=input_space,
            output_space=output_space,
            operation=expr.f,
        )
      elif rasp.is_numerical(inner_expr) and rasp.is_categorical(expr):
        block = numerical_mlp.map_numerical_to_categorical_mlp(
            f=expr.f,
            input_space=input_space,
            output_space=output_space,
            input_value_set=inner_node[nodes.VALUE_SET],
            one_space=one_space,
            hidden_name=f"_hidden_{expr.label}_",
            large_number=mlp_exactness)
      elif rasp.is_numerical(inner_expr) and rasp.is_numerical(expr):
        block = numerical_mlp.map_numerical_mlp(
            f=expr.f,
            input_space=input_space,
            output_space=output_space,
            input_value_set=inner_node[nodes.VALUE_SET],
            one_space=one_space,
            hidden_name=f"_hidden_{expr.label}_",
            large_number=mlp_exactness)
      else:
        raise NotImplementedError("Map does not support "
                                  f"in_type '{inner_expr.type}' and"
                                  f" out_type '{expr.type}'!")

    elif isinstance(expr, rasp.SequenceMap):
      fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label]
      snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label]

      # Check graph structure
      assert fst_expr.label in graph.predecessors(node_id)
      assert snd_expr.label in graph.predecessors(node_id)

      fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS])
      snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS])
      out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      if (isinstance(expr, rasp.LinearSequenceMap) and
          not all(rasp.is_numerical(x) for x in (fst_expr, snd_expr, expr))):
        raise NotImplementedError("Linear SequenceMap only supports numerical "
                                  "inputs/outputs.")
      elif (
          not isinstance(expr, rasp.LinearSequenceMap) and
          not all(rasp.is_categorical(x) for x in (fst_expr, snd_expr, expr))):
        raise NotImplementedError("(Non-linear) SequenceMap only supports "
                                  "categorical inputs/outputs.")

      if isinstance(expr, rasp.LinearSequenceMap):
        assert len(fst_space.basis) == 1
        assert len(snd_space.basis) == 1
        assert len(out_space.basis) == 1
        block = numerical_mlp.linear_sequence_map_numerical_mlp(
            input1_basis_direction=fst_space.basis[0],
            input2_basis_direction=snd_space.basis[0],
            output_basis_direction=out_space.basis[0],
            input1_factor=expr.fst_fac,
            input2_factor=expr.snd_fac,
            hidden_name=f"_hidden_{expr.label}_")
      elif fst_space == snd_space:
        # Both inputs live in the same space, so the two-argument function is
        # applied to the same value twice through a single-input MLP.
        # It's okay to use the local variable expr.f because it is
        # only used within the same loop iteration to create the MLP.
        # pylint: disable=cell-var-from-loop
        basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x),
                                                expr.label)
        block = categorical_mlp.map_categorical_mlp(
            input_space=fst_space, output_space=out_space, operation=basis_fun)
      else:
        basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
        block = categorical_mlp.sequence_map_categorical_mlp(
            input1_space=fst_space,
            input2_space=snd_space,
            output_space=out_space,
            operation=basis_fun,
            one_space=one_space,
            hidden_name=f"_hidden_{expr.label}_")

    elif isinstance(expr, rasp.Aggregate):
      sel_expr: rasp.Select = expr.selector
      agg_expr: rasp.Aggregate = expr

      _check_is_plain_select(sel_expr)

      queries = graph.nodes[sel_expr.queries.label]
      keys = graph.nodes[sel_expr.keys.label]
      sop = graph.nodes[agg_expr.sop.label]

      _check_selector_expression(expr, graph)
      assert agg_expr.sop.label in graph.predecessors(node_id)
      if rasp.get_encoding(agg_expr.sop) != rasp.get_encoding(agg_expr):
        raise ValueError(
            "sop encoding must match output encoding of the aggregate.")
      if rasp.is_categorical(agg_expr) and agg_expr.default is not None:
        raise ValueError("Default for a categorical aggregate must be None. "
                         f"Got {agg_expr.default}")
      if rasp.is_numerical(agg_expr) and agg_expr.default != 0:
        raise ValueError("Default for a numerical aggregate must be 0. "
                         f"Got {agg_expr.default}")

      bos_space = bases.VectorSpaceWithBasis([bos_dir])
      query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
      key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
      value_space = bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS])
      output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      block = categorical_attn.categorical_attn(
          query_space=query_space,
          key_space=key_space,
          value_space=value_space,
          output_space=output_space,
          bos_space=bos_space,
          one_space=one_space,
          attn_fn=_make_attn_basis_fn(sel_expr.predicate),
          default_output=output_space.null_vector(),
          causal=causal,
          always_attend_to_bos=False,
          use_bos_for_default_output=True,
          softmax_coldness=100)

    elif isinstance(expr, rasp.SelectorWidth):
      sel_expr = expr.selector

      # Same restriction as for Aggregate: fail with a clear error before
      # reading the queries and keys of the selector.
      _check_is_plain_select(sel_expr)

      queries = graph.nodes[sel_expr.queries.label]
      keys = graph.nodes[sel_expr.keys.label]
      _check_selector_expression(expr, graph)

      bos_space = bases.VectorSpaceWithBasis([bos_dir])
      query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
      key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
      output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      block = selector_width.selector_width(
          query_space=query_space,
          key_space=key_space,
          output_space=output_space,
          bos_space=bos_space,
          one_space=one_space,
          attn_fn=_make_attn_basis_fn(sel_expr.predicate),
          out_value_set=node[nodes.VALUE_SET],
          categorical_output=rasp.is_categorical(expr),
          # NOTE(review): the `causal` argument is not forwarded here; confirm
          # that this is intended.
          causal=False,
          softmax_coldness=100,
          mlp_large_number=mlp_exactness,
          label=expr.label)

    else:
      raise NotImplementedError(f"Expression {expr} cannot be translated to "
                                "a model component.")

    graph.nodes[node_id][nodes.MODEL_BLOCK] = block