# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration tests for the RASP -> craft stages of the compiler."""

import unittest

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.compiler import test_cases
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.rasp import rasp

_BOS_DIRECTION = "rasp_to_transformer_integration_test_BOS"
_ONE_DIRECTION = "rasp_to_craft_integration_test_ONE"


def _make_input_space(vocab, max_seq_len):
  """Builds the vector space that test inputs are embedded into.

  Args:
    vocab: Token values; each one gets a ("tokens", value) basis direction.
    max_seq_len: Number of position directions; one ("indices", i) direction
      is created for each i in range(max_seq_len).

  Returns:
    The join of the tokens space, the indices space, and the single-direction
    ONE and BOS spaces.
  """
  tokens_space = bases.VectorSpaceWithBasis.from_values("tokens", vocab)
  indices_space = bases.VectorSpaceWithBasis.from_values(
      "indices", range(max_seq_len))
  one_space = bases.VectorSpaceWithBasis.from_names([_ONE_DIRECTION])
  bos_space = bases.VectorSpaceWithBasis.from_names([_BOS_DIRECTION])
  input_space = bases.join_vector_spaces(tokens_space, indices_space,
                                         one_space, bos_space)
  return input_space


def _embed_input(input_seq, input_space):
  """Embeds a token sequence into `input_space`, prepending a BOS position.

  Row 0 is BOS + ONE. Row i + 1 is ("indices", i) + ("tokens", input_seq[i])
  + ONE.

  Args:
    input_seq: Sequence of token values; each must be in the tokens space.
    input_space: Space built by `_make_input_space`.

  Returns:
    A stacked `VectorInBasis` with len(input_seq) + 1 rows.
  """
  bos_vec = input_space.vector_from_basis_direction(
      bases.BasisDirection(_BOS_DIRECTION))
  one_vec = input_space.vector_from_basis_direction(
      bases.BasisDirection(_ONE_DIRECTION))
  embedded_input = [bos_vec + one_vec]
  for i, val in enumerate(input_seq):
    i_vec = input_space.vector_from_basis_direction(
        bases.BasisDirection("indices", i))
    val_vec = input_space.vector_from_basis_direction(
        bases.BasisDirection("tokens", val))
    embedded_input.append(i_vec + val_vec + one_vec)
  return bases.VectorInBasis.stack(embedded_input)


def _embed_output(output_seq, output_space, categorical_output):
  """Embeds an expected output sequence into `output_space`.

  No BOS row is added; the tests compare the result against model output
  with the BOS row removed by `tests_common.strip_bos_token`.

  Args:
    output_seq: Expected per-position values; `None` entries become the null
      vector.
    output_space: Output space of the compiled program.
    categorical_output: If True, each value x becomes the one-hot direction
      (name, x), where name is the name of the space's first basis direction.
      Otherwise x scales the space's first basis direction.

  Returns:
    A stacked `VectorInBasis` with one row per element of `output_seq`.
  """
  embedded_output = []
  output_label = output_space.basis[0].name
  for x in output_seq:
    if x is None:
      out_vec = output_space.null_vector()
    elif categorical_output:
      out_vec = output_space.vector_from_basis_direction(
          bases.BasisDirection(output_label, x))
    else:
      out_vec = x * output_space.vector_from_basis_direction(
          output_space.basis[0])
    embedded_output.append(out_vec)
  return bases.VectorInBasis.stack(embedded_output)


def _compile_program(program, vocab, max_seq_len):
  """Compiles a RASP program into a craft model.

  Runs graph extraction, basis inference, craft-component construction and
  model assembly, using this module's BOS and ONE directions.

  Args:
    program: The RASP expression to compile.
    vocab: Set of input token values.
    max_seq_len: Maximum input sequence length.

  Returns:
    A `(model, output_space)` tuple. `output_space` is spanned by the basis
    stored under `nodes.OUTPUT_BASIS` on the program's sink node.
  """
  extracted = rasp_to_graph.extract_rasp_graph(program)
  basis_inference.infer_bases(
      extracted.graph,
      extracted.sink,
      vocab,
      max_seq_len=max_seq_len,
  )
  expr_to_craft_graph.add_craft_components_to_rasp_graph(
      extracted.graph,
      bos_dir=bases.BasisDirection(_BOS_DIRECTION),
      one_dir=bases.BasisDirection(_ONE_DIRECTION),
  )
  model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
                                                    extracted.sources)
  output_space = bases.VectorSpaceWithBasis(
      extracted.sink[nodes.OUTPUT_BASIS])
  return model, output_space


class CompilerIntegrationTest(tests_common.VectorFnTestCase):
  """Checks that compiled craft models agree with the RASP evaluator."""

  @parameterized.named_parameters(
      dict(
          testcase_name="map",
          program=rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))),
      dict(
          testcase_name="sequence_map",
          program=rasp.categorical(
              rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
                               rasp.indices))),
      dict(
          testcase_name="sequence_map_with_same_input",
          program=rasp.categorical(
              rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
                               rasp.tokens))),
      dict(
          testcase_name="select_aggregate",
          program=rasp.Aggregate(
              rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
              rasp.Map(lambda x: 1, rasp.tokens))))
  def test_rasp_program_and_craft_model_produce_same_output(self, program):
    """Compares model and RASP outputs on every length-1 input."""
    vocab = {0, 1, 2}
    max_seq_len = 3

    model, output_space = _compile_program(program, vocab, max_seq_len)
    input_space = _make_input_space(vocab, max_seq_len)

    for val in vocab:
      test_input = _embed_input([val], input_space)
      rasp_output = program([val])
      expected_output = _embed_output(
          output_seq=rasp_output,
          output_space=output_space,
          categorical_output=rasp.is_categorical(program))
      test_output = model.apply(test_input).project(output_space)
      self.assertVectorAllClose(
          tests_common.strip_bos_token(test_output), expected_output)

  @parameterized.named_parameters(*test_cases.TEST_CASES)
  def test_compiled_models_produce_expected_output(self, program, vocab,
                                                   test_input, expected_output,
                                                   max_seq_len, **kwargs):
    """Checks each shared test case's compiled model against its expectation."""
    del kwargs
    categorical_output = rasp.is_categorical(program)

    model, output_space = _compile_program(program, vocab, max_seq_len)
    input_space = _make_input_space(vocab, max_seq_len)
    if not categorical_output:
      # A numerical program must write to exactly one output direction.
      self.assertLen(output_space.basis, 1)

    test_input_vector = _embed_input(test_input, input_space)
    expected_output_vector = _embed_output(
        output_seq=expected_output,
        output_space=output_space,
        categorical_output=categorical_output)
    test_output = model.apply(test_input_vector).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_output), expected_output_vector)

  @unittest.expectedFailure
  def test_setting_default_values_can_lead_to_wrong_outputs_in_compiled_model(
      self):
    """Documents a known miscompilation caused by Aggregate default values.

    The final assertion is expected to fail. The method takes no extra
    arguments because it is not parameterized. An extra parameter would make
    the call raise TypeError, which `expectedFailure` would accept without
    ever running the scenario below.
    """
    # This is an example program in which setting a default value for aggregate
    # writes a value to the bos token position, which interferes with a later
    # aggregate operation causing the compiled model to have the wrong output.

    vocab = {"a", "b"}
    test_input = ["a"]
    max_seq_len = 2

    # RASP: [False, True]
    # compiled: [False, False, True]
    not_a = rasp.Map(lambda x: x != "a", rasp.tokens)

    # RASP:
    # [[True, False],
    #  [False, False]]
    # compiled:
    # [[False, True, False],
    #  [True, False, False]]
    sel1 = rasp.Select(rasp.tokens, rasp.tokens,
                       lambda k, q: k == "a" and q == "a")

    # RASP: [False, True]
    # compiled: [True, False, True]
    agg1 = rasp.Aggregate(sel1, not_a, default=True)

    # RASP:
    # [[False, True]
    #  [True, True]]
    # compiled:
    # [[True, False, False]
    #  [True, False, False]]
    # because pre-softmax we get
    # [[1.5, 1, 1]
    #  [1.5, 1, 1]]
    # instead of
    # [[0.5, 1, 1]
    #  [0.5, 1, 1]]
    # Because agg1 = True is stored on the BOS token position
    sel2 = rasp.Select(agg1, agg1, lambda k, q: k or q)

    # RASP: [1, 0.5]
    # compiled: [1, 1, 1]
    program = rasp.numerical(
        rasp.Aggregate(sel2, rasp.numerical(not_a), default=1))
    expected_output = [1, 0.5]

    # RASP program gives the correct output
    program_output = program(test_input)
    np.testing.assert_allclose(program_output, expected_output)

    model, output_space = _compile_program(program, vocab, max_seq_len)
    input_space = _make_input_space(vocab, max_seq_len)

    test_input_vector = _embed_input(test_input, input_space)
    # The program is numerical, so the expected values must scale the single
    # output direction. Treating them as categorical labels would fail the
    # embedding itself, before the model output is checked.
    expected_output_vector = _embed_output(
        output_seq=expected_output,
        output_space=output_space,
        categorical_output=rasp.is_categorical(program))
    compiled_model_output = model.apply(test_input_vector).project(output_space)

    # Per the trace above, the compiled model outputs [1, 1, 1] (BOS included)
    # rather than the correct [1, 0.5], so this assertion is expected to fail.
    self.assertVectorAllClose(
        tests_common.strip_bos_token(compiled_model_output),
        expected_output_vector)


if __name__ == "__main__":
  absltest.main()