Text2Text Generation
Transformers
PyTorch
English
t5
Inference Endpoints
text-generation-inference
nreimers commited on
Commit
eb04306
1 Parent(s): eec6ba6
README.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ datasets:
4
+ - sentence-transformers/embedding-training-data
5
+ widget:
6
+ - text: "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
7
+
8
+ license: apache-2.0
9
+ ---
10
+
11
+ # doc2query/msmarco-t5-small-v1
12
+
13
+ This is a [doc2query](https://arxiv.org/abs/1904.08375) model based on T5 (also known as [docT5query](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf)).
14
+
15
+ It can be used for:
16
+ - **Document expansion**: You generate for your paragraphs 20-40 queries and index the paragraphs and the generates queries in a standard BM25 index like Elasticsearch, OpenSearch, or Lucene. The generated queries help to close the lexical gap of lexical search, as the generate queries contain synonyms. Further, it re-weights words giving important words a higher weight even if they appear seldomn in a paragraph. In our [BEIR](https://arxiv.org/abs/2104.08663) paper we showed that BM25+docT5query is a powerful search engine. In the [BEIR repository](https://github.com/UKPLab/beir) we have an example how to use docT5query with Pyserini.
17
+ - **Domain Specific Training Data Generation**: It can be used to generate training data to learn an embedding model. On [SBERT.net](https://www.sbert.net/examples/unsupervised_learning/query_generation/README.html) we have an example how to use the model to generate (query, text) pairs for a given collection of unlabeled texts. These pairs can then be used to train powerful dense embedding models.
18
+
19
+ ## Usage
20
+ ```python
21
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
22
+
23
+ model_name = 'doc2query/msmarco-t5-small-v1'
24
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
25
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
26
+
27
+ text = "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
28
+
29
+
30
+ input_ids = tokenizer.encode(text, max_length=320, truncation=True, return_tensors='pt')
31
+ outputs = model.generate(
32
+ input_ids=input_ids,
33
+ max_length=64,
34
+ do_sample=True,
35
+ top_p=0.95,
36
+ num_return_sequences=5)
37
+
38
+ print("Text:")
39
+ print(text)
40
+
41
+ print("\nGenerated Queries:")
42
+ for i in range(len(outputs)):
43
+ query = tokenizer.decode(outputs[i], skip_special_tokens=True)
44
+ print(f'{i + 1}: {query}')
45
+ ```
46
+
47
+ **Note:** `model.generate()` is non-deterministic. It produces different queries each time you run it.
48
+
49
+ ## Training
50
+ This model fine-tuned [google/t5-v1_1-small](https://huggingface.co/google/t5-v1_1-small) for 31k training steps (about 4 epochs on the 500k training pairs from MS MARCO). For the training script, see the `train_script.py` in this repository.
51
+
52
+ The input-text was truncated to 320 word pieces. Output text was generated up to 64 word pieces.
53
+
54
+ This model was trained on a (query, passage) from the [MS MARCO Passage-Ranking dataset](https://github.com/microsoft/MSMARCO-Passage-Ranking).
55
+
56
+
57
+
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/t5-v1_1-small",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 1024,
7
+ "d_kv": 64,
8
+ "d_model": 512,
9
+ "decoder_start_token_id": 0,
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "initializer_factor": 1.0,
14
+ "is_encoder_decoder": true,
15
+ "layer_norm_epsilon": 1e-06,
16
+ "model_type": "t5",
17
+ "num_decoder_layers": 8,
18
+ "num_heads": 6,
19
+ "num_layers": 8,
20
+ "output_past": true,
21
+ "pad_token_id": 0,
22
+ "relative_attention_num_buckets": 32,
23
+ "tie_word_embeddings": false,
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.11.3",
26
+ "use_cache": true,
27
+ "vocab_size": 32128
28
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e70479fde0478b478ba9ba05d071ccf4eea2bfae51215166ef94ee918837f4d0
3
+ size 307934749
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"], "model_max_length": 512, "name_or_path": "google/t5-v1_1-small", "special_tokens_map_file": "/root/.cache/huggingface/transformers/3ad6f8335c1b1ef8966245899d47dcf735abd134d21fd7d26f621fe45ac01184.c94798918c92ded6aeef2d2f0e666d2cc4145eca1aa6e1336fde07f2e13e2f46", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}
train_script.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ from torch.utils.data import Dataset, IterableDataset
4
+ import gzip
5
+ import json
6
+ from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
7
+ import sys
8
+ from datetime import datetime
9
+ import torch
10
+ import random
11
+ from shutil import copyfile
12
+ import os
13
+ import wandb
14
+ import random
15
+ import re
16
+
17
+
18
+ logging.basicConfig(
19
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
20
+ datefmt="%Y-%m-%d %H:%M:%S",
21
+ handlers=[logging.StreamHandler(sys.stdout)],
22
+ )
23
+
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--model_name", default="google/t5-v1_1-base")
26
+ parser.add_argument("--train_files", required=True, nargs='+', default=[])
27
+ parser.add_argument("--epochs", default=1, type=int)
28
+ parser.add_argument("--batch_size", default=32, type=int)
29
+ parser.add_argument("--max_source_length", default=320, type=int)
30
+ parser.add_argument("--max_target_length", default=64, type=int)
31
+ parser.add_argument("--name", required=True)
32
+ parser.add_argument("--train_size", default=10*1000*1000, type=int)
33
+ parser.add_argument("--eval_size", default=10000, type=int)
34
+ parser.add_argument("--fp16", default=False, action='store_true')
35
+ args = parser.parse_args()
36
+
37
+ wandb.init(project="doc2query", name=f"{args.name}-{args.model_name}")
38
+
39
+
40
+
41
+
42
+ class PairDataset:
43
+ def __init__(self, filepath):
44
+ self.filepath = filepath
45
+ self.examples = []
46
+
47
+ def __iter__(self):
48
+ print("open", self.filepath)
49
+ with gzip.open(self.filepath, 'rt') as fIn:
50
+ for line in fIn:
51
+ example = self.get_example(json.loads(line))
52
+ if example is not None:
53
+ self.examples.append(example)
54
+ yield example
55
+
56
+ while True:
57
+ random.shuffle(self.examples)
58
+ for ex in self.examples:
59
+ yield ex
60
+
61
+
62
+ def get_example(self, raw_example):
63
+ if isinstance(raw_example, dict):
64
+ return [raw_example['query'], random.choice(raw_example['pos'])]
65
+ else:
66
+ return [raw_example[0], raw_example[1]]
67
+
68
+
69
+ class RedditTitleDataset(PairDataset):
70
+ def get_example(self, raw_example):
71
+ return [self.clean_title(raw_example['title']), raw_example['body']]
72
+
73
+
74
+ def clean_title(self, text):
75
+ text = text.replace("&amp;", "&").strip()
76
+ if text.startswith("["):
77
+ text = re.sub("^\[[a-zA-Z0-9]+\]", "", text).strip()
78
+
79
+ if text.endswith("]"):
80
+ text = re.sub("\[[a-zA-Z0-9\.]+\]$", "", text).strip()
81
+
82
+ if text.startswith("/r"):
83
+ text = re.sub("^/[a-zA-Z0-9/]+[;,: \-]+", "", text).strip()
84
+
85
+ return text
86
+
87
+
88
+ class StackExchangeTitleBodyDataset(PairDataset):
89
+ def get_example(self, raw_example):
90
+ return raw_example['texts']
91
+
92
+
93
+ class MultiDataset(IterableDataset):
94
+ def __init__(self, filepaths, num_samples):
95
+ self.num_samples = num_samples
96
+ self.datasets = []
97
+ self.data_iterators = []
98
+
99
+ for filepath in filepaths:
100
+ if 'reddit_title_text' in filepath:
101
+ dataset = RedditTitleDataset(filepath)
102
+ elif 'stackexchange_archive/jsonl' in filepath:
103
+ dataset = StackExchangeTitleBodyDataset(filepath)
104
+ else:
105
+ dataset = PairDataset(filepath)
106
+ self.datasets.append(dataset)
107
+ self.data_iterators.append(iter(dataset))
108
+
109
+ def __len__(self):
110
+ return self.num_samples
111
+
112
+ def __iter__(self):
113
+ while True:
114
+ for dataset in self.data_iterators:
115
+ yield next(dataset)
116
+
117
+ random.shuffle(self.data_iterators)
118
+
119
+ def delete_examples_cache(self):
120
+ for dataset in self.datasets:
121
+ dataset.examples = []
122
+
123
+
124
+
125
+ def main():
126
+ ############ Model
127
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
128
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
129
+
130
+ save_steps = 1000
131
+
132
+ output_dir = 'output/'+args.name+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
133
+ print("Output dir:", output_dir)
134
+
135
+ # Write self to path
136
+ os.makedirs(output_dir, exist_ok=True)
137
+
138
+ train_script_path = os.path.join(output_dir, 'train_script.py')
139
+ copyfile(__file__, train_script_path)
140
+ with open(train_script_path, 'a') as fOut:
141
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
142
+
143
+ ####
144
+
145
+ training_args = Seq2SeqTrainingArguments(
146
+ output_dir=output_dir,
147
+ fp16=args.fp16,
148
+ fp16_backend="amp",
149
+ per_device_train_batch_size=args.batch_size,
150
+ evaluation_strategy="steps",
151
+ save_steps=save_steps,
152
+ logging_steps=100,
153
+ eval_steps=save_steps, #logging_steps,
154
+ warmup_steps=1000,
155
+ save_total_limit=1,
156
+ num_train_epochs=args.epochs,
157
+ report_to="wandb",
158
+ )
159
+
160
+ ############ Arguments
161
+
162
+ ############ Load datasets
163
+
164
+
165
+ train_dataset = MultiDataset(args.train_files, args.train_size)
166
+ train_dataset_iter = iter(train_dataset)
167
+ eval_dataset = [next(train_dataset_iter) for _ in range(args.eval_size)]
168
+ train_dataset.delete_examples_cache() #Make sure dev data is no re-used for training
169
+ print("Target:", eval_dataset[0][0])
170
+ print("Input:", eval_dataset[0][1])
171
+
172
+ print("Train dataset len:", len(train_dataset))
173
+
174
+
175
+ def data_collator(examples):
176
+ targets = [row[0] for row in examples]
177
+ inputs = [row[1] for row in examples]
178
+ label_pad_token_id = -100
179
+
180
+ model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)
181
+
182
+ # Setup the tokenizer for targets
183
+ with tokenizer.as_target_tokenizer():
184
+ labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)
185
+
186
+ # replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss.
187
+ labels["input_ids"] = [
188
+ [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
189
+ ]
190
+
191
+
192
+ model_inputs["labels"] = torch.tensor(labels["input_ids"])
193
+ return model_inputs
194
+
195
+ ## Define the trainer
196
+ trainer = Seq2SeqTrainer(
197
+ model=model,
198
+ args=training_args,
199
+ train_dataset=train_dataset,
200
+ eval_dataset=eval_dataset,
201
+ tokenizer=tokenizer,
202
+ data_collator=data_collator
203
+ )
204
+
205
+ ### Save the model
206
+ train_result = trainer.train()
207
+ trainer.save_model()
208
+
209
+
210
+ if __name__ == "__main__":
211
+ main()
212
+
213
+ # Script was called via:
214
+ #python train_hf_trainer.py --model_name google/t5-v1_1-small --train_files /home/sbert_pretrained_models/datasets/embedding-training-data/msmarco-triplets.jsonl.gz --name msmarco --train_size 2000000
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dfb7cfff430f81d2d2488d6a1a55dcdaa1fd3b830bc9bcf1697c4dcd8c1498b
3
+ size 2991