Harsh Trivedi committed on
Commit
fba83b2
1 Parent(s): b276bf4
README.md ADDED
@@ -0,0 +1,53 @@
---
tags:
- question-answering
- multi-step-reasoning
- multi-hop-reasoning
thumbnail: https://raw.githubusercontent.com/StonyBrookNLP/teabreac/main/teabreac_icon.png
license: cc-by-4.0
---

# What's this?

This is one of the models reported in the paper ["Teaching Broad Reasoning Skills for Multi-Step QA by Generating Hard Contexts"](https://arxiv.org/abs/2205.12496).

The paper proposes a procedure for synthetically generating a QA dataset, TeaBReaC, for pretraining language models for robust multi-step reasoning. Pretraining plain LMs like BART and T5, as well as numerate LMs like NT5, PReasM, and POET, on TeaBReaC leads to improved downstream performance on several multi-step QA datasets. Please check out the paper for the details.

We release the following models:

- **A:** Base models finetuned on target datasets: `{target_dataset}-{base_model}`
- **B:** Base models pretrained on TeaBReaC: `teabreac-{base_model}`
- **C:** Base models pretrained on TeaBReaC and then finetuned on target datasets: `teabreac-{target_dataset}-{base_model}`

The `base_model` above can be one of `bart-large`, `t5-large`, `t5-3b`, `nt5-small`, or `preasm-large`.
The `target_dataset` above can be one of `drop`, `tatqa`, `iirc-gold`, `iirc-retrieved`, or `numglue`.

The **A** models are released only for completeness / reproducibility. For your end application, you most likely want either **B** or **C**.
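
For illustration only, the naming scheme composes as in the sketch below. The `StonyBrookNLP/` Hub namespace used here is an assumption rather than something stated above; check the actual model listing for the exact repository names.

```python
# Illustrative sketch (not from the paper or repo): composing model identifiers
# from the A/B/C naming scheme above. The "StonyBrookNLP/" namespace is an
# assumption; consult the Hugging Face Hub for the exact repository names.
base_model = "bart-large"     # one of: bart-large, t5-large, t5-3b, nt5-small, preasm-large
target_dataset = "tatqa"      # one of: drop, tatqa, iirc-gold, iirc-retrieved, numglue

model_a = f"StonyBrookNLP/{target_dataset}-{base_model}"           # A: finetuned on target only
model_b = f"StonyBrookNLP/teabreac-{base_model}"                   # B: TeaBReaC-pretrained
model_c = f"StonyBrookNLP/teabreac-{target_dataset}-{base_model}"  # C: TeaBReaC-pretrained + finetuned
print(model_a, model_b, model_c)
```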

# How to use it?

Please check out the details in our [github repository](https://github.com/stonybrooknlp/teabreac), but in a nutshell:

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from digit_tokenization import enable_digit_tokenization  # digit_tokenization.py from https://github.com/stonybrooknlp/teabreac

model_name = "poet-large-tatqa"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)  # fast tokenizers don't work with digit tokenization
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
enable_digit_tokenization(tokenizer)
input_texts = [
    "answer_me: Who scored the first touchdown of the game? " +
    "context: ... Oakland would get the early lead in the first quarter as quarterback JaMarcus Russell completed a 20-yard touchdown pass to rookie wide receiver Chaz Schilens..."
    # Note: some models have slightly different qn/ctxt format. See the github repo.
]
inputs = tokenizer(
    input_texts, return_tensors="pt",
    truncation=True, max_length=800,
    add_special_tokens=True, padding=True,
)
generated_ids = model.generate(
    inputs.input_ids, attention_mask=inputs.attention_mask,
    min_length=1, max_length=50,
)
generated_predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
generated_predictions = [
    tokenizer.fix_decoded_text(generated_prediction) for generated_prediction in generated_predictions
]
# => ["Chaz Schilens"]
```
added_tokens.json ADDED
@@ -0,0 +1,66 @@
{
  "<ss>": 50309,
  "[CHPD]": 50313,
  "[CND]": 50310,
  "[COLD]": 50303,
  "[CTANS]": 50314,
  "[CTCNINSINA]": 50318,
  "[CTCNINTA]": 50317,
  "[CTOUINSINA]": 50316,
  "[CTOUINTA]": 50315,
  "[DAY]": 50326,
  "[DFB]": 50305,
  "[DFD]": 50307,
  "[DFE]": 50306,
  "[MDFD]": 50308,
  "[MONTH]": 50327,
  "[ROWD]": 50304,
  "[SCBN]": 50323,
  "[SCBP]": 50322,
  "[SCD]": 50321,
  "[SCMN]": 50325,
  "[SCMP]": 50324,
  "[SIND]": 50312,
  "[SNE]": 50320,
  "[SNS]": 50319,
  "[SOMD]": 50311,
  "[YEAR]": 50328,
  "are_items_different": 50267,
  "are_items_same": 50266,
  "arg_bool": 50295,
  "arg_intersection": 50296,
  "arg_maximum_date": 50280,
  "arg_maximum_number": 50269,
  "arg_minimum_date": 50279,
  "arg_minimum_number": 50268,
  "arithmetic_division": 50275,
  "arithmetic_mean_single": 50272,
  "arithmetic_multiplication": 50277,
  "arithmetic_subtraction": 50276,
  "arithmetic_sum_multiple": 50274,
  "arithmetic_sum_single": 50273,
  "boolean": 50265,
  "compare_dates": 50284,
  "compare_numbers": 50278,
  "date_subtraction": 50283,
  "filter_a_where_b_is_compared_to": 50292,
  "filter_a_where_b_is_compared_to_date": 50288,
  "filter_a_where_b_is_given_value": 50291,
  "filter_a_where_b_is_max": 50289,
  "filter_a_where_b_is_max_date": 50286,
  "filter_a_where_b_is_min": 50290,
  "filter_a_where_b_is_min_date": 50287,
  "grouped_count": 50298,
  "grouped_mean": 50299,
  "grouped_sum": 50300,
  "intersection": 50297,
  "kth_highest": 50301,
  "kth_lowest": 50302,
  "list_subtraction": 50285,
  "logical_and": 50294,
  "logical_or": 50293,
  "maximum_date": 50282,
  "maximum_number": 50271,
  "minimum_date": 50281,
  "minimum_number": 50270
}
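
For context, added_tokens.json extends BART's base vocabulary (IDs up to 50264) with marker and reasoning-operation tokens starting at ID 50265. The snippet below is a small sketch of how one might inspect these entries, assuming the tokenizer files from this repository are loaded as in the README.

```python
# Illustrative sketch: inspecting the added reasoning-operation tokens.
# Assumes the tokenizer from this repository is available under the name used
# in the README ("poet-large-tatqa"); adjust the path/name as needed.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("poet-large-tatqa", use_fast=False)

added = tokenizer.get_added_vocab()                         # tokens added on top of BART's 50265-entry base vocab
print(len(added))                                           # expected to match the 64 entries in added_tokens.json
print(added["arithmetic_sum_single"])                       # 50273
print(tokenizer.convert_tokens_to_ids("compare_numbers"))   # 50278
```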
config.json ADDED
@@ -0,0 +1,72 @@
{
  "_name_or_path": "SivilTaram/poet-sql-digit",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 0,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "normalize_before": false,
  "num_beams": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "scale_embedding": false,
  "task_specific_params": {
    "summarization": {
      "length_penalty": 1.0,
      "max_length": 128,
      "min_length": 12,
      "num_beams": 4
    },
    "summarization_cnn": {
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "num_beams": 4
    },
    "summarization_xsum": {
      "length_penalty": 1.0,
      "max_length": 62,
      "min_length": 11,
      "num_beams": 6
    }
  },
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "use_cache": true,
  "vocab_size": 50329
}
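
One consistency check worth noting: the `vocab_size` of 50329 in config.json should equal BART-large's base vocabulary (50265) plus the 64 entries in added_tokens.json. A minimal sketch, assuming both JSON files sit in the working directory:

```python
# Illustrative sanity check: vocab_size in config.json accounts for the added tokens.
import json

with open("config.json") as f:
    config = json.load(f)
with open("added_tokens.json") as f:
    added_tokens = json.load(f)

base_vocab = 50265  # BART-large's original vocabulary size
assert config["vocab_size"] == base_vocab + len(added_tokens)  # 50265 + 64 == 50329
print(config["model_type"], config["d_model"], config["vocab_size"])  # bart 1024 50329
```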
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d135f2011957d9d5714f08fc19b291dc0c9f3e86a921b4a7a864ad6ca278703b
size 1625793025
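
This entry is a Git LFS pointer, not the weights themselves; the actual ~1.6 GB checkpoint lives in LFS storage. A hedged sketch of one way to fetch it with `huggingface_hub` (the `repo_id` below is an assumption; substitute this repository's actual Hub name):

```python
# Illustrative sketch: downloading the LFS-backed checkpoint via huggingface_hub.
# The repo_id is an assumption; replace it with the actual Hub name of this repository.
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(repo_id="StonyBrookNLP/poet-large-tatqa", filename="pytorch_model.bin")
print(weights_path)  # local cache path of the downloaded checkpoint
```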
special_tokens_map.json ADDED
@@ -0,0 +1,63 @@
{
  "additional_special_tokens": [
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9"
  ],
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
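
Note that the digits 0-9 are registered as `additional_special_tokens`, which presumably supports the digit-level number tokenization that the README's `enable_digit_tokenization` helper relies on. A minimal sketch, assuming special_tokens_map.json is in the working directory:

```python
# Illustrative sketch: the digits 0-9 appear as additional special tokens in
# special_tokens_map.json (presumably in support of digit-level tokenization).
import json

with open("special_tokens_map.json") as f:
    special_tokens = json.load(f)

print(special_tokens["additional_special_tokens"])  # ['0', '1', ..., '9']
print(special_tokens["mask_token"]["content"])      # '<mask>'
```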
tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
{
  "add_prefix_space": false,
  "add_special_tokens": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "__type": "AddedToken",
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "__type": "AddedToken",
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "errors": "replace",
  "mask_token": {
    "__type": "AddedToken",
    "content": "<mask>",
    "lstrip": true,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "model_max_length": 1024,
  "name_or_path": "SivilTaram/poet-sql-digit",
  "pad_token": {
    "__type": "AddedToken",
    "content": "<pad>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "__type": "AddedToken",
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "special_tokens_map_file": null,
  "tokenizer_class": "BartTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
vocab.json ADDED
The diff for this file is too large to render. See raw diff