# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A set of RASP programs and input/output pairs used in integration tests.""" from tracr.compiler import lib from tracr.rasp import rasp UNIVERSAL_TEST_CASES = [ dict( testcase_name="frac_prevs_1", program=lib.make_frac_prevs(rasp.tokens == "l"), vocab={"h", "e", "l", "o"}, test_input=list("hello"), expected_output=[0.0, 0.0, 1 / 3, 1 / 2, 2 / 5], max_seq_len=5), dict( testcase_name="frac_prevs_2", program=lib.make_frac_prevs(rasp.tokens == "("), vocab={"a", "b", "c", "(", ")"}, test_input=list("a()b(c))"), expected_output=[0.0, 1 / 2, 1 / 3, 1 / 4, 2 / 5, 2 / 6, 2 / 7, 2 / 8], max_seq_len=10), dict( testcase_name="frac_prevs_3", program=lib.make_frac_prevs(rasp.tokens == ")"), vocab={"a", "b", "c", "(", ")"}, test_input=list("a()b(c))"), expected_output=[0.0, 0.0, 1 / 3, 1 / 4, 1 / 5, 1 / 6, 2 / 7, 3 / 8], max_seq_len=10, ), dict( testcase_name="shift_by_one", program=lib.shift_by(1, rasp.tokens), vocab={"a", "b", "c", "d"}, test_input=list("abcd"), expected_output=[None, "a", "b", "c"], max_seq_len=5, ), dict( testcase_name="shift_by_two", program=lib.shift_by(2, rasp.tokens), vocab={"a", "b", "c", "d"}, test_input=list("abcd"), expected_output=[None, None, "a", "b"], max_seq_len=5, ), dict( testcase_name="detect_pattern_a", program=lib.detect_pattern(rasp.tokens, "a"), vocab={"a", "b", "c", "d"}, test_input=list("bacd"), expected_output=[False, True, False, False], max_seq_len=5, ), dict( testcase_name="detect_pattern_ab", program=lib.detect_pattern(rasp.tokens, "ab"), vocab={"a", "b"}, test_input=list("aaba"), expected_output=[None, False, True, False], max_seq_len=5, ), dict( testcase_name="detect_pattern_ab_2", program=lib.detect_pattern(rasp.tokens, "ab"), vocab={"a", "b"}, test_input=list("abaa"), expected_output=[None, True, False, False], max_seq_len=5, ), dict( testcase_name="detect_pattern_ab_3", program=lib.detect_pattern(rasp.tokens, "ab"), vocab={"a", "b"}, test_input=list("aaaa"), expected_output=[None, False, False, False], max_seq_len=5, ), dict( testcase_name="detect_pattern_abc", program=lib.detect_pattern(rasp.tokens, "abc"), vocab={"a", "b", "c"}, test_input=list("abcabc"), expected_output=[None, None, True, False, False, True], max_seq_len=6, ), ] TEST_CASES = UNIVERSAL_TEST_CASES + [ dict( testcase_name="reverse_1", program=lib.make_reverse(rasp.tokens), vocab={"a", "b", "c", "d"}, test_input=list("abcd"), expected_output=list("dcba"), max_seq_len=5), dict( testcase_name="reverse_2", program=lib.make_reverse(rasp.tokens), vocab={"a", "b", "c", "d"}, test_input=list("abc"), expected_output=list("cba"), max_seq_len=5), dict( testcase_name="reverse_3", program=lib.make_reverse(rasp.tokens), vocab={"a", "b", "c", "d"}, test_input=list("ad"), expected_output=list("da"), max_seq_len=5), dict( testcase_name="reverse_4", program=lib.make_reverse(rasp.tokens), vocab={"a", "b", "c", "d"}, test_input=["c"], expected_output=["c"], max_seq_len=5), dict( testcase_name="length_categorical_1", program=rasp.categorical(lib.make_length()), vocab={"a", "b", "c", "d"}, test_input=list("abc"), expected_output=[3, 3, 3], max_seq_len=3), dict( testcase_name="length_categorical_2", program=rasp.categorical(lib.make_length()), vocab={"a", "b", "c", "d"}, test_input=list("ad"), expected_output=[2, 2], max_seq_len=3), dict( testcase_name="length_categorical_3", program=rasp.categorical(lib.make_length()), vocab={"a", "b", "c", "d"}, test_input=["c"], expected_output=[1], max_seq_len=3), dict( testcase_name="length_numerical_1", program=rasp.numerical(lib.make_length()), vocab={"a", "b", "c", "d"}, test_input=list("abc"), expected_output=[3, 3, 3], max_seq_len=3), dict( testcase_name="length_numerical_2", program=rasp.numerical(lib.make_length()), vocab={"a", "b", "c", "d"}, test_input=list("ad"), expected_output=[2, 2], max_seq_len=3), dict( testcase_name="length_numerical_3", program=rasp.numerical(lib.make_length()), vocab={"a", "b", "c", "d"}, test_input=["c"], expected_output=[1], max_seq_len=3), dict( testcase_name="pair_balance_1", program=lib.make_pair_balance(rasp.tokens, "(", ")"), vocab={"a", "b", "c", "(", ")"}, test_input=list("a()b(c))"), expected_output=[0.0, 1 / 2, 0.0, 0.0, 1 / 5, 1 / 6, 0.0, -1 / 8], max_seq_len=10), dict( testcase_name="shuffle_dyck2_1", program=lib.make_shuffle_dyck(pairs=["()", "{}"]), vocab={"(", ")", "{", "}"}, test_input=list("({)}"), expected_output=[1, 1, 1, 1], max_seq_len=5), dict( testcase_name="shuffle_dyck2_2", program=lib.make_shuffle_dyck(pairs=["()", "{}"]), vocab={"(", ")", "{", "}"}, test_input=list("(){)}"), expected_output=[0, 0, 0, 0, 0], max_seq_len=5), dict( testcase_name="shuffle_dyck2_3", program=lib.make_shuffle_dyck(pairs=["()", "{}"]), vocab={"(", ")", "{", "}"}, test_input=list("{}("), expected_output=[0, 0, 0], max_seq_len=5), dict( testcase_name="shuffle_dyck3_1", program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]), vocab={"(", ")", "{", "}", "[", "]"}, test_input=list("({)[}]"), expected_output=[1, 1, 1, 1, 1, 1], max_seq_len=6), dict( testcase_name="shuffle_dyck3_2", program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]), vocab={"(", ")", "{", "}", "[", "]"}, test_input=list("(){)}"), expected_output=[0, 0, 0, 0, 0], max_seq_len=6), dict( testcase_name="shuffle_dyck3_3", program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]), vocab={"(", ")", "{", "}", "[", "]"}, test_input=list("{}[(]"), expected_output=[0, 0, 0, 0, 0], max_seq_len=6), dict( testcase_name="hist", program=lib.make_hist(), vocab={"a", "b", "c", "d"}, test_input=list("abac"), expected_output=[2, 1, 2, 1], max_seq_len=5, ), dict( testcase_name="sort_unique_1", program=lib.make_sort_unique(vals=rasp.tokens, keys=rasp.tokens), vocab={1, 2, 3, 4}, test_input=[2, 4, 3, 1], expected_output=[1, 2, 3, 4], max_seq_len=5), dict( testcase_name="sort_unique_2", program=lib.make_sort_unique(vals=rasp.tokens, keys=1 - rasp.indices), vocab={"a", "b", "c", "d"}, test_input=list("abcd"), expected_output=["d", "c", "b", "a"], max_seq_len=5), dict( testcase_name="sort_1", program=lib.make_sort( vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1), vocab={1, 2, 3, 4}, test_input=[2, 4, 3, 1], expected_output=[1, 2, 3, 4], max_seq_len=5), dict( testcase_name="sort_2", program=lib.make_sort( vals=rasp.tokens, keys=1 - rasp.indices, max_seq_len=5, min_key=1), vocab={"a", "b", "c", "d"}, test_input=list("abcd"), expected_output=["d", "c", "b", "a"], max_seq_len=5), dict( testcase_name="sort_3", program=lib.make_sort( vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1), vocab={1, 2, 3, 4}, test_input=[2, 4, 1, 2], expected_output=[1, 2, 2, 4], max_seq_len=5), dict( testcase_name="sort_freq", program=lib.make_sort_freq(max_seq_len=5), vocab={1, 2, 3, 4}, test_input=[2, 4, 2, 1], expected_output=[2, 2, 4, 1], max_seq_len=5), dict( testcase_name="make_count_less_freq_categorical_1", program=lib.make_count_less_freq(n=2), vocab={"a", "b", "c", "d"}, test_input=["a", "a", "a", "b", "b", "c"], expected_output=[3, 3, 3, 3, 3, 3], max_seq_len=6), dict( testcase_name="make_count_less_freq_categorical_2", program=lib.make_count_less_freq(n=2), vocab={"a", "b", "c", "d"}, test_input=["a", "a", "c", "b", "b", "c"], expected_output=[6, 6, 6, 6, 6, 6], max_seq_len=6), dict( testcase_name="make_count_less_freq_numerical_1", program=rasp.numerical(lib.make_count_less_freq(n=2)), vocab={"a", "b", "c", "d"}, test_input=["a", "a", "a", "b", "b", "c"], expected_output=[3, 3, 3, 3, 3, 3], max_seq_len=6), dict( testcase_name="make_count_less_freq_numerical_2", program=rasp.numerical(lib.make_count_less_freq(n=2)), vocab={"a", "b", "c", "d"}, test_input=["a", "a", "c", "b", "b", "c"], expected_output=[6, 6, 6, 6, 6, 6], max_seq_len=6), dict( testcase_name="make_count_1", program=lib.make_count(rasp.tokens, "a"), vocab={"a", "b", "c"}, test_input=["a", "a", "a", "b", "b", "c"], expected_output=[3, 3, 3, 3, 3, 3], max_seq_len=8, ), dict( testcase_name="make_count_2", program=lib.make_count(rasp.tokens, "a"), vocab={"a", "b", "c"}, test_input=["c", "a", "b", "c"], expected_output=[1, 1, 1, 1], max_seq_len=8, ), dict( testcase_name="make_count_3", program=lib.make_count(rasp.tokens, "a"), vocab={"a", "b", "c"}, test_input=["b", "b", "c"], expected_output=[0, 0, 0], max_seq_len=8, ), dict( testcase_name="make_nary_sequencemap_1", program=lib.make_nary_sequencemap( lambda x, y, z: x + y - z, rasp.tokens, rasp.tokens, rasp.indices), vocab={1, 2, 3}, test_input=[1, 2, 3], expected_output=[2, 3, 4], max_seq_len=5, ), dict( testcase_name="make_nary_sequencemap_2", program=lib.make_nary_sequencemap( lambda x, y, z: x * y / z, rasp.indices, rasp.indices, rasp.tokens), vocab={1, 2, 3}, test_input=[1, 2, 3], expected_output=[0, 1 / 2, 4 / 3], max_seq_len=3, ) ] # make_nary_sequencemap(f, *sops) CAUSAL_TEST_CASES = UNIVERSAL_TEST_CASES + [ dict( testcase_name="selector_width", program=rasp.SelectorWidth( rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)), vocab={"a", "b", "c", "d"}, test_input=list("abcd"), expected_output=[1, 2, 3, 4], max_seq_len=5), ]