# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a craft model from a computational graph."""

import collections
from typing import Dict, List, Sequence

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft import transformers
from tracr.rasp import rasp

Node = nodes.Node
NodeID = nodes.NodeID


def _get_longest_path_length_to_node(graph: nx.DiGraph, sources: Sequence[Node],
                                     node: Node) -> int:
  """Returns the length of the longest path from sources to node.

  Only SOps count towards the length of a path.

  The longest path is found with a dynamic program over the ancestors of
  `node`, visited in topological order. This is linear in the size of that
  sub-graph, whereas enumerating every simple path can take exponential time
  when the graph has many re-converging branches.

  Args:
    graph: DAG to compute longest path in.
    sources: List of starting nodes, longest path will be a maximum over all.
    node: Target node.

  Returns:
    Number of steps needed for the longest path from the source to the node, or
    -1 if there is no path from any of the sources to the target node.
  """
  if node in sources:
    return 0

  def sop_weight(node_id: NodeID) -> int:
    """1 if the node's expression is an SOp, else 0."""
    return int(isinstance(graph.nodes[node_id][nodes.EXPR], rasp.SOp))

  target_id = node[nodes.ID]
  source_ids = {source[nodes.ID] for source in sources}

  # Only the target and its ancestors can lie on a path that ends at the target.
  relevant_ids = nx.ancestors(graph, target_id) | {target_id}

  # max_sops[v] is the largest number of SOps on any path from some source to v,
  # counting both endpoints. Nodes that no source reaches get no entry.
  max_sops: Dict[NodeID, int] = {}
  for node_id in nx.topological_sort(graph.subgraph(relevant_ids)):
    # Every predecessor of an ancestor is itself an ancestor, so all of them
    # have already been visited in topological order.
    candidates = [
        max_sops[pred]
        for pred in graph.predecessors(node_id)
        if pred in max_sops
    ]
    if node_id in source_ids:
      candidates.append(0)  # A path may also start at this node.
    if candidates:
      max_sops[node_id] = max(candidates) + sop_weight(node_id)

  # The path length is the number of SOps on it minus one. If no source reaches
  # the target, this evaluates to -1.
  return max_sops.get(target_id, 0) - 1


def _node_is_attn(node: Node) -> bool:
  """Returns True if node is an attention layer."""
  return isinstance(
      node.get(nodes.MODEL_BLOCK),
      (transformers.AttentionHead, transformers.MultiAttentionHead))


def _node_is_mlp(node: Node) -> bool:
  """Returns True if node is an MLP layer."""
  return isinstance(node.get(nodes.MODEL_BLOCK), transformers.MLP)


def _node_is_residual_block(node: Node) -> bool:
  """Returns True if node is a valid residual block (Attn followed by MLP).

  A valid residual block is a `SeriesWithResiduals` made of exactly two
  blocks: an attention head (single or multi), then an MLP.
  """
  block = node.get(nodes.MODEL_BLOCK)
  if not isinstance(block, transformers.SeriesWithResiduals):
    return False
  if len(block.blocks) != 2:
    return False
  attn, mlp = block.blocks
  return (isinstance(
      attn, (transformers.AttentionHead, transformers.MultiAttentionHead)) and
          isinstance(mlp, transformers.MLP))


def _all_attn_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff all nodes are attention layers (or nodes is empty)."""
  return all(_node_is_attn(node) for node in node_list)


def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff all nodes are MLP layers (or nodes is empty)."""
  return all(_node_is_mlp(node) for node in node_list)


def _allocate_modules_to_layers(graph: nx.DiGraph,
                                sources: Sequence[Node]) -> Dict[int, int]:
  """Allocate all nodes in compute graph to layers.

  First, computes the longest path from the input to each node that is a model
  component (not input and output nodes). The longest path to a model
  component (its "depth") determines a layer in which we can place it while
  ensuring that all necessary previous computations have already happened.

  This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...]

  In the special case where there are only Attention layers at one depth level
  and only MLP layers in the next depth level, they are treated as if they were
  at the same depth, because attention layers always come before MLP layers
  for the same depth.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes

  Returns:
    A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
    are in the order attention, mlp, attention, mlp, ...

  Raises:
    ValueError: If the graph contains no model components (no node carries an
      attention, MLP or residual craft block).
  """
  layer_allocation: Dict[int, int] = collections.defaultdict(lambda: -1)
  depth_by_node_id: Dict[int, int] = dict()
  nodes_by_depth: Dict[int, List[Node]] = collections.defaultdict(list)

  # Compute depth of all model components (longest path from source to node)
  for node_id, node in graph.nodes.items():
    if (_node_is_mlp(node) or _node_is_attn(node)
        or _node_is_residual_block(node)):
      # Node is a model component
      longest_path_len = _get_longest_path_length_to_node(graph, sources, node)
      depth_by_node_id[node_id] = longest_path_len
      nodes_by_depth[longest_path_len].append(node)

  if not nodes_by_depth:
    # Without this check, min()/max() below fail with an opaque error message.
    raise ValueError(
        "Cannot allocate layers: the graph contains no model components "
        "(attention, MLP or residual craft blocks).")

  # If at level `depth` there are only attention heads and at level `depth + 1`
  # there are only MLPs, we can condense them into one level
  # TODO(b/255936816): Think about improving this heuristic. The heuristic is
  # not optimal, and only catches very basic opportunities for optimization. It
  # is easy to come up with opportunities for optimization that it does not
  # catch.
  min_depth, max_depth = min(nodes_by_depth.keys()), max(nodes_by_depth.keys())
  depth = min_depth
  while depth < max_depth:
    if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes(
        nodes_by_depth[depth + 1]):
      # Condense by decrementing the depth of all nodes starting from depth+1
      for update_depth in range(depth + 1, max_depth + 1):
        for node in nodes_by_depth[update_depth]:
          node_id = node[nodes.ID]
          depth_by_node_id[node_id] = update_depth - 1
        nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth])
        nodes_by_depth[update_depth] = []
      max_depth -= 1
    depth += 1

  # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
  # Each depth level owns two consecutive layers: an attention layer at an even
  # index, then an MLP layer at the following odd index.
  # NOTE(review): current_depth starts at 1 and only grows, so components whose
  # depth is below 1 (e.g. -1 when no source reaches them) never match and get
  # no layer. They are therefore left out of the model. Confirm this is
  # intended.
  current_layer = 0
  current_depth = 1
  for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]):
    while depth > current_depth:
      current_depth += 1
      current_layer += 2
    if depth == current_depth:
      if _node_is_residual_block(graph.nodes[node_id]):
        # The MLP half is placed in current_layer + 1 by craft_graph_to_model.
        layer_allocation[node_id] = current_layer
      else:
        is_mlp = _node_is_mlp(graph.nodes[node_id])
        layer_allocation[node_id] = current_layer + int(is_mlp)

  return layer_allocation


def craft_graph_to_model(
    graph: nx.DiGraph,
    sources: Sequence[Node]) -> transformers.SeriesWithResiduals:
  """Translates a RASP graph with craft blocks into a full craft model.

  1. Allocates modules to layers, assuming layers in the order attention, MLP,
     attention, MLP, ...
  2. Creates subspaces for all inputs and outputs, and builds residual stream.
  3. Assembles everything into a craft model and returns it.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes

  Returns:
    A craft model that can be compiled to model weights.

  Raises:
    ValueError: On invalid input (if the craft_graph does not have craft blocks
      already specified)
  """
  layer_allocation = _allocate_modules_to_layers(graph, sources)
  blocks_by_layer = collections.defaultdict(list)
  model_blocks = []

  residual_space = bases.VectorSpaceWithBasis([])

  for node_id, layer_no in layer_allocation.items():
    node = graph.nodes[node_id]
    block = node.get(nodes.MODEL_BLOCK)

    if _node_is_residual_block(node):
      assert isinstance(block, transformers.SeriesWithResiduals)
      assert len(block.blocks) == 2
      attn_block, mlp_block = block.blocks
      residual_space = bases.join_vector_spaces(residual_space,
                                                attn_block.residual_space,
                                                mlp_block.residual_space)
      # Split the residual block across its attention layer and the MLP layer
      # that follows it.
      blocks_by_layer[layer_no].append(attn_block)
      blocks_by_layer[layer_no + 1].append(mlp_block)
    elif block:
      residual_space = bases.join_vector_spaces(residual_space,
                                                block.residual_space)
      blocks_by_layer[layer_no].append(block)

  for layer_no, layer_blocks in sorted(
      blocks_by_layer.items(), key=lambda x: x[0]):
    # Every block must read from and write to the shared residual stream.
    for block in layer_blocks:
      block.residual_space = residual_space

    if layer_blocks:
      if layer_no % 2 == 0:  # Attention Layer
        multi_head_attn = transformers.MultiAttentionHead(layer_blocks)
        model_blocks.append(multi_head_attn)
      else:  # MLP Layer
        parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks)
        model_blocks.append(parallel_mlp)

  return transformers.SeriesWithResiduals(model_blocks)