Vladimir Mikulik
add typing_extensions to list of deps.
d4d39d0
raw history blame
No virus
12.1 kB
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A set of RASP programs and input/output pairs used in integration tests."""
from tracr.compiler import lib
from tracr.rasp import rasp
# Cases that are prepended to both TEST_CASES and CAUSAL_TEST_CASES.
#
# Every entry is a dict with the same six keys:
#   testcase_name:   label identifying the case.
#   program:         the RASP program under test.
#   vocab:           set of token values the input is drawn from.
#   test_input:      input token sequence.
#   expected_output: expected output, one value per input position.
#   max_seq_len:     presumably the maximum sequence length the program is
#                    compiled for (never shorter than test_input here) --
#                    TODO confirm against the test harness.
UNIVERSAL_TEST_CASES = [
    # --- lib.make_frac_prevs ---
    {
        "testcase_name": "frac_prevs_1",
        "program": lib.make_frac_prevs(rasp.tokens == "l"),
        "vocab": {"h", "e", "l", "o"},
        "test_input": list("hello"),
        "expected_output": [0.0, 0.0, 1 / 3, 1 / 2, 2 / 5],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "frac_prevs_2",
        "program": lib.make_frac_prevs(rasp.tokens == "("),
        "vocab": {"a", "b", "c", "(", ")"},
        "test_input": list("a()b(c))"),
        "expected_output": [
            0.0, 1 / 2, 1 / 3, 1 / 4, 2 / 5, 2 / 6, 2 / 7, 2 / 8,
        ],
        "max_seq_len": 10,
    },
    {
        "testcase_name": "frac_prevs_3",
        "program": lib.make_frac_prevs(rasp.tokens == ")"),
        "vocab": {"a", "b", "c", "(", ")"},
        "test_input": list("a()b(c))"),
        "expected_output": [
            0.0, 0.0, 1 / 3, 1 / 4, 1 / 5, 1 / 6, 2 / 7, 3 / 8,
        ],
        "max_seq_len": 10,
    },
    # --- lib.shift_by ---
    # Positions with no token `offset` steps earlier are expected to be None.
    {
        "testcase_name": "shift_by_one",
        "program": lib.shift_by(1, rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": [None, "a", "b", "c"],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "shift_by_two",
        "program": lib.shift_by(2, rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": [None, None, "a", "b"],
        "max_seq_len": 5,
    },
    # --- lib.detect_pattern ---
    # The first len(pattern) - 1 positions are expected to be None.
    {
        "testcase_name": "detect_pattern_a",
        "program": lib.detect_pattern(rasp.tokens, "a"),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("bacd"),
        "expected_output": [False, True, False, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_ab",
        "program": lib.detect_pattern(rasp.tokens, "ab"),
        "vocab": {"a", "b"},
        "test_input": list("aaba"),
        "expected_output": [None, False, True, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_ab_2",
        "program": lib.detect_pattern(rasp.tokens, "ab"),
        "vocab": {"a", "b"},
        "test_input": list("abaa"),
        "expected_output": [None, True, False, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_ab_3",
        "program": lib.detect_pattern(rasp.tokens, "ab"),
        "vocab": {"a", "b"},
        "test_input": list("aaaa"),
        "expected_output": [None, False, False, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_abc",
        "program": lib.detect_pattern(rasp.tokens, "abc"),
        "vocab": {"a", "b", "c"},
        "test_input": list("abcabc"),
        "expected_output": [None, None, True, False, False, True],
        "max_seq_len": 6,
    },
]
# Full case list: everything in UNIVERSAL_TEST_CASES followed by the cases
# below. Entries use the same six keys as the UNIVERSAL_TEST_CASES entries
# (testcase_name, program, vocab, test_input, expected_output, max_seq_len).
#
# Each case builds its own program with a separate lib.* call, in list order;
# the calls are deliberately not hoisted into shared variables.
TEST_CASES = [
    *UNIVERSAL_TEST_CASES,
    # --- lib.make_reverse ---
    {
        "testcase_name": "reverse_1",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": list("dcba"),
        "max_seq_len": 5,
    },
    {
        "testcase_name": "reverse_2",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abc"),
        "expected_output": list("cba"),
        "max_seq_len": 5,
    },
    {
        "testcase_name": "reverse_3",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("ad"),
        "expected_output": list("da"),
        "max_seq_len": 5,
    },
    {
        "testcase_name": "reverse_4",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["c"],
        "expected_output": ["c"],
        "max_seq_len": 5,
    },
    # --- lib.make_length, wrapped with rasp.categorical ---
    {
        "testcase_name": "length_categorical_1",
        "program": rasp.categorical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abc"),
        "expected_output": [3, 3, 3],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_categorical_2",
        "program": rasp.categorical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("ad"),
        "expected_output": [2, 2],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_categorical_3",
        "program": rasp.categorical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["c"],
        "expected_output": [1],
        "max_seq_len": 3,
    },
    # --- lib.make_length, wrapped with rasp.numerical ---
    {
        "testcase_name": "length_numerical_1",
        "program": rasp.numerical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abc"),
        "expected_output": [3, 3, 3],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_numerical_2",
        "program": rasp.numerical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("ad"),
        "expected_output": [2, 2],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_numerical_3",
        "program": rasp.numerical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["c"],
        "expected_output": [1],
        "max_seq_len": 3,
    },
    # --- lib.make_pair_balance ---
    {
        "testcase_name": "pair_balance_1",
        "program": lib.make_pair_balance(rasp.tokens, "(", ")"),
        "vocab": {"a", "b", "c", "(", ")"},
        "test_input": list("a()b(c))"),
        "expected_output": [
            0.0, 1 / 2, 0.0, 0.0, 1 / 5, 1 / 6, 0.0, -1 / 8,
        ],
        "max_seq_len": 10,
    },
    # --- lib.make_shuffle_dyck, two bracket pairs ---
    {
        "testcase_name": "shuffle_dyck2_1",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}"]),
        "vocab": {"(", ")", "{", "}"},
        "test_input": list("({)}"),
        "expected_output": [1, 1, 1, 1],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "shuffle_dyck2_2",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}"]),
        "vocab": {"(", ")", "{", "}"},
        "test_input": list("(){)}"),
        "expected_output": [0, 0, 0, 0, 0],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "shuffle_dyck2_3",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}"]),
        "vocab": {"(", ")", "{", "}"},
        "test_input": list("{}("),
        "expected_output": [0, 0, 0],
        "max_seq_len": 5,
    },
    # --- lib.make_shuffle_dyck, three bracket pairs ---
    {
        "testcase_name": "shuffle_dyck3_1",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        "vocab": {"(", ")", "{", "}", "[", "]"},
        "test_input": list("({)[}]"),
        "expected_output": [1, 1, 1, 1, 1, 1],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "shuffle_dyck3_2",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        "vocab": {"(", ")", "{", "}", "[", "]"},
        "test_input": list("(){)}"),
        "expected_output": [0, 0, 0, 0, 0],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "shuffle_dyck3_3",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        "vocab": {"(", ")", "{", "}", "[", "]"},
        "test_input": list("{}[(]"),
        "expected_output": [0, 0, 0, 0, 0],
        "max_seq_len": 6,
    },
    # --- lib.make_hist ---
    {
        "testcase_name": "hist",
        "program": lib.make_hist(),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abac"),
        "expected_output": [2, 1, 2, 1],
        "max_seq_len": 5,
    },
    # --- lib.make_sort_unique ---
    {
        "testcase_name": "sort_unique_1",
        "program": lib.make_sort_unique(vals=rasp.tokens, keys=rasp.tokens),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 3, 1],
        "expected_output": [1, 2, 3, 4],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "sort_unique_2",
        # Keys decrease with position, so the expected output is reversed.
        "program": lib.make_sort_unique(
            vals=rasp.tokens, keys=1 - rasp.indices),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": ["d", "c", "b", "a"],
        "max_seq_len": 5,
    },
    # --- lib.make_sort ---
    {
        "testcase_name": "sort_1",
        "program": lib.make_sort(
            vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 3, 1],
        "expected_output": [1, 2, 3, 4],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "sort_2",
        "program": lib.make_sort(
            vals=rasp.tokens, keys=1 - rasp.indices, max_seq_len=5,
            min_key=1),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": ["d", "c", "b", "a"],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "sort_3",
        # Input contains a repeated key (2 appears twice).
        "program": lib.make_sort(
            vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 1, 2],
        "expected_output": [1, 2, 2, 4],
        "max_seq_len": 5,
    },
    # --- lib.make_sort_freq ---
    {
        "testcase_name": "sort_freq",
        "program": lib.make_sort_freq(max_seq_len=5),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 2, 1],
        "expected_output": [2, 2, 4, 1],
        "max_seq_len": 5,
    },
    # --- lib.make_count_less_freq, used as returned ---
    {
        "testcase_name": "make_count_less_freq_categorical_1",
        "program": lib.make_count_less_freq(n=2),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "a", "b", "b", "c"],
        "expected_output": [3, 3, 3, 3, 3, 3],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "make_count_less_freq_categorical_2",
        "program": lib.make_count_less_freq(n=2),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "c", "b", "b", "c"],
        "expected_output": [6, 6, 6, 6, 6, 6],
        "max_seq_len": 6,
    },
    # --- lib.make_count_less_freq, wrapped with rasp.numerical ---
    {
        "testcase_name": "make_count_less_freq_numerical_1",
        "program": rasp.numerical(lib.make_count_less_freq(n=2)),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "a", "b", "b", "c"],
        "expected_output": [3, 3, 3, 3, 3, 3],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "make_count_less_freq_numerical_2",
        "program": rasp.numerical(lib.make_count_less_freq(n=2)),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "c", "b", "b", "c"],
        "expected_output": [6, 6, 6, 6, 6, 6],
        "max_seq_len": 6,
    },
    # --- lib.make_count ---
    {
        "testcase_name": "make_count_1",
        "program": lib.make_count(rasp.tokens, "a"),
        "vocab": {"a", "b", "c"},
        "test_input": ["a", "a", "a", "b", "b", "c"],
        "expected_output": [3, 3, 3, 3, 3, 3],
        "max_seq_len": 8,
    },
    {
        "testcase_name": "make_count_2",
        "program": lib.make_count(rasp.tokens, "a"),
        "vocab": {"a", "b", "c"},
        "test_input": ["c", "a", "b", "c"],
        "expected_output": [1, 1, 1, 1],
        "max_seq_len": 8,
    },
    {
        "testcase_name": "make_count_3",
        "program": lib.make_count(rasp.tokens, "a"),
        "vocab": {"a", "b", "c"},
        "test_input": ["b", "b", "c"],
        "expected_output": [0, 0, 0],
        "max_seq_len": 8,
    },
    # --- lib.make_nary_sequencemap ---
    {
        "testcase_name": "make_nary_sequencemap_1",
        # Per position: token + token - index.
        "program": lib.make_nary_sequencemap(
            lambda x, y, z: x + y - z,
            rasp.tokens, rasp.tokens, rasp.indices),
        "vocab": {1, 2, 3},
        "test_input": [1, 2, 3],
        "expected_output": [2, 3, 4],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "make_nary_sequencemap_2",
        # Per position: index * index / token (tokens here are never zero).
        "program": lib.make_nary_sequencemap(
            lambda x, y, z: x * y / z,
            rasp.indices, rasp.indices, rasp.tokens),
        "vocab": {1, 2, 3},
        "test_input": [1, 2, 3],
        "expected_output": [0, 1 / 2, 4 / 3],
        "max_seq_len": 3,
    },
]
# Call shape used by the two make_nary_sequencemap cases above:
# lib.make_nary_sequencemap(f, *sops).
# Case list for the causal setting: everything in UNIVERSAL_TEST_CASES plus
# the cases below. Entries use the same six keys as the UNIVERSAL_TEST_CASES
# entries (testcase_name, program, vocab, test_input, expected_output,
# max_seq_len).
CAUSAL_TEST_CASES = [
    *UNIVERSAL_TEST_CASES,
    {
        "testcase_name": "selector_width",
        # Select with Comparison.TRUE; the expected width grows by one per
        # position (1, 2, 3, 4) -- presumably because of causal masking,
        # TODO confirm against the test that consumes this list.
        "program": rasp.SelectorWidth(
            rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": [1, 2, 3, 4],
        "max_seq_len": 5,
    },
]