# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.craft_graph_to_model."""

from absl.testing import absltest
from absl.testing import parameterized
import networkx as nx
from tracr.compiler import craft_graph_to_model
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.rasp import rasp


class CraftAllocateModulesToLayersTest(parameterized.TestCase):

  def _get_dummy_block(self, block_type):
    if block_type == "ATTN":
      return categorical_attn.categorical_attn(
          query_space=bases.VectorSpaceWithBasis.from_names(["query"]),
          key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]),
          value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]),
          output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
          bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]),
          one_space=bases.VectorSpaceWithBasis.from_names(["one"]),
          attn_fn=lambda x, y: True,
      )
    elif block_type == "MLP":
      return categorical_mlp.map_categorical_mlp(
          input_space=bases.VectorSpaceWithBasis.from_names(["input"]),
          output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
          operation=lambda x: x,
      )
    else:
      return None

  def test_get_longest_path_length_to_node_returns_expected_result(self):
    """Creates a graph and checks the longest path for each node."""

    # Node IDs:
    # 0 -- 1 -- 2 -- 3 ------------ 4
    #                /             /
    # 5 -- 6 ---------- 7 -- 8 -- 9
    #
    # 10
    # Expected return values:
    # 0 -- 1 -- 2 -- 3 ------------ 5
    #                /             /
    # 0 -- 1 ---------- 2 -- 3 -- 4
    #
    # -1

    graph = nx.DiGraph()
    node_ids = list(range(11))
    expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1]
    for node_id, res in zip(node_ids, expected_results):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              "expected_result": res
          })
    graph.add_edge(0, 1)
    graph.add_edge(1, 2)
    graph.add_edge(2, 3)
    graph.add_edge(3, 4)
    graph.add_edge(5, 6)
    graph.add_edge(6, 7)
    graph.add_edge(7, 8)
    graph.add_edge(8, 9)
    graph.add_edge(6, 3)
    graph.add_edge(9, 4)

    sources = [graph.nodes[0], graph.nodes[5]]

    for node_id, node in graph.nodes.items():
      result = craft_graph_to_model._get_longest_path_length_to_node(
          graph, sources, node)
      self.assertEqual(result, node["expected_result"])

  def test_allocate_modules_to_layers_returns_expected_result(self):
    """Creates a graph and checks if the correct layer assignment is returned."""

    # Computation Graph:
    # INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT
    #                 /      /           /
    # INPUT -- MLP --- MLP          ATTN
    #                     \        /
    #                      ATTN
    # Node IDs:
    # 0 -- 1 -- 2 -- 3 -- 4 -- 5
    #          /    /    /
    # 6 -- 7 ---- 8     9
    #              \   /
    #               10
    # Expected layer allocation:
    # -1 -- 0 -- 3 -- 4 -- 7 -- -1
    #           /    /    /
    # -1 -- 1 --- 3     6
    #              \   /
    #               4

    graph = nx.DiGraph()
    node_ids = list(range(11))
    types = [
        "INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP", "MLP",
        "ATTN", "ATTN"
    ]
    expected_results = [-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4]
    for node_id, node_type, res in zip(node_ids, types, expected_results):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
              "expected_result": res
          })
    graph.add_edge(0, 1)
    graph.add_edge(1, 2)
    graph.add_edge(2, 3)
    graph.add_edge(3, 4)
    graph.add_edge(4, 5)
    graph.add_edge(6, 7)
    graph.add_edge(7, 2)
    graph.add_edge(7, 8)
    graph.add_edge(8, 3)
    graph.add_edge(8, 10)
    graph.add_edge(9, 4)
    graph.add_edge(10, 9)

    craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
        graph=graph,
        sink=graph.nodes[10],
        sources=[graph.nodes[0], graph.nodes[6]])

    layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
        craft_graph.graph, craft_graph.sources)
    for node_id, node in graph.nodes.items():
      self.assertEqual(layer_allocation[node_id], node["expected_result"])

  def test_allocate_modules_to_layers_returns_expected_result_for_chain(self):
    """Tests a chain of alternating attention layers and MLPs."""

    # Computation Graph:
    # INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT
    # Node IDs:
    # 0 -- 1 -- 2 -- 3 -- 4 -- 5
    # Expected layer allocation:
    # -1 -- 0 -- 1 -- 2 -- 3 -- -1

    graph = nx.DiGraph()
    node_ids = list(range(6))
    types = ["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"]
    expected_results = [-1, 0, 1, 2, 3, -1]
    for node_id, node_type, res in zip(node_ids, types, expected_results):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
              "expected_result": res
          })
    graph.add_edge(0, 1)
    graph.add_edge(1, 2)
    graph.add_edge(2, 3)
    graph.add_edge(3, 4)
    graph.add_edge(4, 5)

    craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
        graph=graph,
        sink=graph.nodes[5],
        sources=[graph.nodes[0]])

    layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
        craft_graph.graph, craft_graph.sources)
    for node_id, node in graph.nodes.items():
      self.assertEqual(layer_allocation[node_id], node["expected_result"])


if __name__ == "__main__":
  absltest.main()