nreimers commited on
Commit
b6bc5d5
1 Parent(s): 6e727ea
1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 1024,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
README.md ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: sentence-similarity
3
+ tags:
4
+ - sentence-transformers
5
+ - feature-extraction
6
+ - sentence-similarity
7
+ language: en
8
+ license: apache-2.0
9
+ ---
10
+
11
+
12
+ # all-roberta-large-v1
13
+ This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 1024 dimensional dense vector space and can be used for tasks like clustering or semantic search.
14
+
15
+ ## Usage (Sentence-Transformers)
16
+ Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
17
+
18
+ ```
19
+ pip install -U sentence-transformers
20
+ ```
21
+
22
+ Then you can use the model like this:
23
+ ```python
24
+ from sentence_transformers import SentenceTransformer
25
+ sentences = ["This is an example sentence", "Each sentence is converted"]
26
+
27
+ model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
28
+ embeddings = model.encode(sentences)
29
+ print(embeddings)
30
+ ```
31
+
32
+ ## Usage (HuggingFace Transformers)
33
+ Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the right pooling-operation on-top of the contextualized word embeddings.
34
+
35
+ ```python
36
+ from transformers import AutoTokenizer, AutoModel
37
+ import torch
38
+ import torch.nn.functional as F
39
+
40
+ #Mean Pooling - Take attention mask into account for correct averaging
41
+ def mean_pooling(model_output, attention_mask):
42
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
43
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
44
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
45
+
46
+
47
+ # Sentences we want sentence embeddings for
48
+ sentences = ['This is an example sentence', 'Each sentence is converted']
49
+
50
+ # Load model from HuggingFace Hub
51
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-roberta-large-v1')
52
+ model = AutoModel.from_pretrained('sentence-transformers/all-roberta-large-v1')
53
+
54
+ # Tokenize sentences
55
+ encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
56
+
57
+ # Compute token embeddings
58
+ with torch.no_grad():
59
+ model_output = model(**encoded_input)
60
+
61
+ # Perform pooling
62
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
63
+
64
+ # Normalize embeddings
65
+ sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
66
+
67
+ print("Sentence embeddings:")
68
+ print(sentence_embeddings)
69
+ ```
70
+
71
+ ## Evaluation Results
72
+
73
+ For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/all-roberta-large-v1)
74
+
75
+ ------
76
+
77
+ ## Background
78
+
79
+ The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised
80
+ contrastive learning objective. We used the pretrained [`roberta-large`](https://huggingface.co/roberta-large) model and fine-tuned in on a
81
+ 1B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
82
+
83
+ We developped this model during the
84
+ [Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104),
85
+ organized by Hugging Face. We developped this model as part of the project:
86
+ [Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project: 7 TPUs v3-8, as well as intervention from Googles Flax, JAX, and Cloud team member about efficient deep learning frameworks.
87
+
88
+ ## Intended uses
89
+
90
+ Our model is intented to be used as a sentence and short paragraph encoder. Given an input text, it ouptuts a vector which captures
91
+ the semantic information. The sentence vector may be used for information retrieval, clustering or sentence similarity tasks.
92
+
93
+ By default, input text longer than 128 word pieces is truncated.
94
+
95
+
96
+ ## Training procedure
97
+
98
+ ### Pre-training
99
+
100
+ We use the pretrained [`roberta-large`](https://huggingface.co/roberta-large). Please refer to the model card for more detailed information about the pre-training procedure.
101
+
102
+ ### Fine-tuning
103
+
104
+ We fine-tune the model using a contrastive objective. Formally, we compute the cosine similarity from each possible sentence pairs from the batch.
105
+ We then apply the cross entropy loss by comparing with true pairs.
106
+
107
+ #### Hyper parameters
108
+
109
+ We trained ou model on a TPU v3-8. We train the model during 400k steps using a batch size of 256 (32 per TPU core).
110
+ We use a learning rate warm up of 500. The sequence length was limited to 128 tokens. We used the AdamW optimizer with
111
+ a 2e-5 learning rate. The full training script is accessible in this current repository: `train_script.py`.
112
+
113
+ #### Training data
114
+
115
+ We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is above 1 billion sentences.
116
+ We sampled each dataset given a weighted probability which configuration is detailed in the `data_config.json` file.
117
+
118
+
119
+ | Dataset | Paper | Number of training tuples |
120
+ |--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
121
+ | [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
122
+ | [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts) | [paper](https://aclanthology.org/2020.acl-main.447/) | 116,288,806 |
123
+ | [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
124
+ | [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
125
+ | [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
126
+ | [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
127
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs | - | 25,316,456 |
128
+ | [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
129
+ | [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
130
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
131
+ | [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,151,414 |
132
+ | [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
133
+ | [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
134
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
135
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
136
+ | [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
137
+ | [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
138
+ | [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
139
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles) | | 304,525 |
140
+ | AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
141
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (bodies) | | 250,519 |
142
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles+bodies) | | 250,460 |
143
+ | [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
144
+ | [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
145
+ | [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
146
+ | [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
147
+ | [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
148
+ | [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
149
+ | [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
150
+ | [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
151
+ | **Total** | | **1,124,818,467** |
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "roberta-large",
3
+ "architectures": [
4
+ "RobertaForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "eos_token_id": 2,
9
+ "gradient_checkpointing": false,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 514,
17
+ "model_type": "roberta",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 1,
21
+ "position_embedding_type": "absolute",
22
+ "transformers_version": "4.8.2",
23
+ "type_vocab_size": 1,
24
+ "use_cache": true,
25
+ "vocab_size": 50265
26
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.0.0",
4
+ "transformers": "4.6.1",
5
+ "pytorch": "1.8.1"
6
+ }
7
+ }
data_config.json ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "stackexchange_title_body/skeptics.stackexchange.com.jsonl.gz",
4
+ "lines": 10009,
5
+ "weight": 1
6
+ },
7
+ {
8
+ "name": "stackexchange_title_body/writers.stackexchange.com.jsonl.gz",
9
+ "lines": 10157,
10
+ "weight": 1
11
+ },
12
+ {
13
+ "name": "stackexchange_title_body/astronomy.stackexchange.com.jsonl.gz",
14
+ "lines": 10462,
15
+ "weight": 1
16
+ },
17
+ {
18
+ "name": "stackexchange_title_body/vi.stackexchange.com.jsonl.gz",
19
+ "lines": 10551,
20
+ "weight": 1
21
+ },
22
+ {
23
+ "name": "stackexchange_title_body/cstheory.stackexchange.com.jsonl.gz",
24
+ "lines": 10642,
25
+ "weight": 1
26
+ },
27
+ {
28
+ "name": "stackexchange_title_body/engineering.stackexchange.com.jsonl.gz",
29
+ "lines": 10753,
30
+ "weight": 1
31
+ },
32
+ {
33
+ "name": "stackexchange_title_body/french.stackexchange.com.jsonl.gz",
34
+ "lines": 10794,
35
+ "weight": 1
36
+ },
37
+ {
38
+ "name": "stackexchange_title_body/economics.stackexchange.com.jsonl.gz",
39
+ "lines": 11115,
40
+ "weight": 1
41
+ },
42
+ {
43
+ "name": "stackexchange_title_body/anime.stackexchange.com.jsonl.gz",
44
+ "lines": 11444,
45
+ "weight": 1
46
+ },
47
+ {
48
+ "name": "stackexchange_title_body/islam.stackexchange.com.jsonl.gz",
49
+ "lines": 11853,
50
+ "weight": 1
51
+ },
52
+ {
53
+ "name": "stackexchange_title_body/expressionengine.stackexchange.com.jsonl.gz",
54
+ "lines": 11866,
55
+ "weight": 1
56
+ },
57
+ {
58
+ "name": "stackexchange_title_body/politics.stackexchange.com.jsonl.gz",
59
+ "lines": 11894,
60
+ "weight": 1
61
+ },
62
+ {
63
+ "name": "stackexchange_title_body/history.stackexchange.com.jsonl.gz",
64
+ "lines": 12021,
65
+ "weight": 1
66
+ },
67
+ {
68
+ "name": "stackexchange_title_body/christianity.stackexchange.com.jsonl.gz",
69
+ "lines": 12108,
70
+ "weight": 1
71
+ },
72
+ {
73
+ "name": "stackexchange_title_body/boardgames.stackexchange.com.jsonl.gz",
74
+ "lines": 12149,
75
+ "weight": 1
76
+ },
77
+ {
78
+ "name": "stackexchange_title_body/civicrm.stackexchange.com.jsonl.gz",
79
+ "lines": 12543,
80
+ "weight": 1
81
+ },
82
+ {
83
+ "name": "stackexchange_title_body/craftcms.stackexchange.com.jsonl.gz",
84
+ "lines": 12574,
85
+ "weight": 1
86
+ },
87
+ {
88
+ "name": "stackexchange_title_body/hinduism.stackexchange.com.jsonl.gz",
89
+ "lines": 13450,
90
+ "weight": 1
91
+ },
92
+ {
93
+ "name": "stackexchange_title_body/networkengineering.stackexchange.com.jsonl.gz",
94
+ "lines": 13454,
95
+ "weight": 1
96
+ },
97
+ {
98
+ "name": "stackexchange_title_body/german.stackexchange.com.jsonl.gz",
99
+ "lines": 13950,
100
+ "weight": 1
101
+ },
102
+ {
103
+ "name": "stackexchange_title_body/philosophy.stackexchange.com.jsonl.gz",
104
+ "lines": 14829,
105
+ "weight": 1
106
+ },
107
+ {
108
+ "name": "stackexchange_title_body/gardening.stackexchange.com.jsonl.gz",
109
+ "lines": 15136,
110
+ "weight": 1
111
+ },
112
+ {
113
+ "name": "stackexchange_title_body/space.stackexchange.com.jsonl.gz",
114
+ "lines": 15142,
115
+ "weight": 1
116
+ },
117
+ {
118
+ "name": "stackexchange_title_body/bicycles.stackexchange.com.jsonl.gz",
119
+ "lines": 16353,
120
+ "weight": 1
121
+ },
122
+ {
123
+ "name": "stackexchange_title_body/quant.stackexchange.com.jsonl.gz",
124
+ "lines": 17261,
125
+ "weight": 1
126
+ },
127
+ {
128
+ "name": "stackexchange_title_body/puzzling.stackexchange.com.jsonl.gz",
129
+ "lines": 17851,
130
+ "weight": 1
131
+ },
132
+ {
133
+ "name": "stackexchange_title_body/law.stackexchange.com.jsonl.gz",
134
+ "lines": 17941,
135
+ "weight": 1
136
+ },
137
+ {
138
+ "name": "stackexchange_title_body/arduino.stackexchange.com.jsonl.gz",
139
+ "lines": 19553,
140
+ "weight": 1
141
+ },
142
+ {
143
+ "name": "stackexchange_title_body/aviation.stackexchange.com.jsonl.gz",
144
+ "lines": 20139,
145
+ "weight": 1
146
+ },
147
+ {
148
+ "name": "stackexchange_title_body/softwarerecs.stackexchange.com.jsonl.gz",
149
+ "lines": 20142,
150
+ "weight": 1
151
+ },
152
+ {
153
+ "name": "stackexchange_title_body/movies.stackexchange.com.jsonl.gz",
154
+ "lines": 20181,
155
+ "weight": 1
156
+ },
157
+ {
158
+ "name": "stackexchange_title_body/music.stackexchange.com.jsonl.gz",
159
+ "lines": 20636,
160
+ "weight": 1
161
+ },
162
+ {
163
+ "name": "stackexchange_title_body/emacs.stackexchange.com.jsonl.gz",
164
+ "lines": 21055,
165
+ "weight": 1
166
+ },
167
+ {
168
+ "name": "stackexchange_title_body/dsp.stackexchange.com.jsonl.gz",
169
+ "lines": 21252,
170
+ "weight": 1
171
+ },
172
+ {
173
+ "name": "flickr30k_captions.jsonl.gz",
174
+ "lines": 317695,
175
+ "weight": 1
176
+ },
177
+ {
178
+ "name": "coco_captions.jsonl.gz",
179
+ "lines": 828395,
180
+ "weight": 1
181
+ },
182
+ {
183
+ "name": "codesearchnet.jsonl.gz",
184
+ "lines": 1151414,
185
+ "weight": 1
186
+ },
187
+ {
188
+ "name": "stackexchange_title_body/japanese.stackexchange.com.jsonl.gz",
189
+ "lines": 22056,
190
+ "weight": 2
191
+ },
192
+ {
193
+ "name": "stackexchange_title_body/mechanics.stackexchange.com.jsonl.gz",
194
+ "lines": 22868,
195
+ "weight": 2
196
+ },
197
+ {
198
+ "name": "stackexchange_title_body/crypto.stackexchange.com.jsonl.gz",
199
+ "lines": 23231,
200
+ "weight": 2
201
+ },
202
+ {
203
+ "name": "stackexchange_title_body/cooking.stackexchange.com.jsonl.gz",
204
+ "lines": 23705,
205
+ "weight": 2
206
+ },
207
+ {
208
+ "name": "stackexchange_title_body/photo.stackexchange.com.jsonl.gz",
209
+ "lines": 23753,
210
+ "weight": 2
211
+ },
212
+ {
213
+ "name": "stackexchange_title_body/workplace.stackexchange.com.jsonl.gz",
214
+ "lines": 24189,
215
+ "weight": 2
216
+ },
217
+ {
218
+ "name": "stackexchange_title_body/biology.stackexchange.com.jsonl.gz",
219
+ "lines": 24447,
220
+ "weight": 2
221
+ },
222
+ {
223
+ "name": "stackexchange_title_body/bitcoin.stackexchange.com.jsonl.gz",
224
+ "lines": 25374,
225
+ "weight": 2
226
+ },
227
+ {
228
+ "name": "stackexchange_title_body/worldbuilding.stackexchange.com.jsonl.gz",
229
+ "lines": 26763,
230
+ "weight": 2
231
+ },
232
+ {
233
+ "name": "stackexchange_title_body/datascience.stackexchange.com.jsonl.gz",
234
+ "lines": 27397,
235
+ "weight": 2
236
+ },
237
+ {
238
+ "name": "stackexchange_title_body/ux.stackexchange.com.jsonl.gz",
239
+ "lines": 29403,
240
+ "weight": 2
241
+ },
242
+ {
243
+ "name": "stackexchange_title_body/webapps.stackexchange.com.jsonl.gz",
244
+ "lines": 29697,
245
+ "weight": 2
246
+ },
247
+ {
248
+ "name": "stackexchange_title_body/graphicdesign.stackexchange.com.jsonl.gz",
249
+ "lines": 30233,
250
+ "weight": 2
251
+ },
252
+ {
253
+ "name": "stackexchange_title_body/raspberrypi.stackexchange.com.jsonl.gz",
254
+ "lines": 30625,
255
+ "weight": 2
256
+ },
257
+ {
258
+ "name": "stackexchange_title_body/money.stackexchange.com.jsonl.gz",
259
+ "lines": 32021,
260
+ "weight": 2
261
+ },
262
+ {
263
+ "name": "stackexchange_title_body/judaism.stackexchange.com.jsonl.gz",
264
+ "lines": 32028,
265
+ "weight": 2
266
+ },
267
+ {
268
+ "name": "stackexchange_title_body/ethereum.stackexchange.com.jsonl.gz",
269
+ "lines": 32760,
270
+ "weight": 2
271
+ },
272
+ {
273
+ "name": "stackexchange_title_body/academia.stackexchange.com.jsonl.gz",
274
+ "lines": 34331,
275
+ "weight": 2
276
+ },
277
+ {
278
+ "name": "stackexchange_title_body/chemistry.stackexchange.com.jsonl.gz",
279
+ "lines": 34506,
280
+ "weight": 2
281
+ },
282
+ {
283
+ "name": "stackexchange_title_body/webmasters.stackexchange.com.jsonl.gz",
284
+ "lines": 34559,
285
+ "weight": 2
286
+ },
287
+ {
288
+ "name": "stackexchange_title_body/meta.stackoverflow.com.jsonl.gz",
289
+ "lines": 36456,
290
+ "weight": 2
291
+ },
292
+ {
293
+ "name": "stackexchange_title_body/cs.stackexchange.com.jsonl.gz",
294
+ "lines": 38314,
295
+ "weight": 2
296
+ },
297
+ {
298
+ "name": "stackexchange_title_body/travel.stackexchange.com.jsonl.gz",
299
+ "lines": 41227,
300
+ "weight": 2
301
+ },
302
+ {
303
+ "name": "stackexchange_title_body/rpg.stackexchange.com.jsonl.gz",
304
+ "lines": 42303,
305
+ "weight": 2
306
+ },
307
+ {
308
+ "name": "stackexchange_title_body/codereview.stackexchange.com.jsonl.gz",
309
+ "lines": 45765,
310
+ "weight": 3
311
+ },
312
+ {
313
+ "name": "stackexchange_title_body/gamedev.stackexchange.com.jsonl.gz",
314
+ "lines": 46485,
315
+ "weight": 3
316
+ },
317
+ {
318
+ "name": "stackexchange_title_body/android.stackexchange.com.jsonl.gz",
319
+ "lines": 51608,
320
+ "weight": 3
321
+ },
322
+ {
323
+ "name": "stackexchange_title_body/softwareengineering.stackexchange.com.jsonl.gz",
324
+ "lines": 53942,
325
+ "weight": 3
326
+ },
327
+ {
328
+ "name": "stackexchange_title_body/security.stackexchange.com.jsonl.gz",
329
+ "lines": 58000,
330
+ "weight": 3
331
+ },
332
+ {
333
+ "name": "stackexchange_title_body/diy.stackexchange.com.jsonl.gz",
334
+ "lines": 60083,
335
+ "weight": 3
336
+ },
337
+ {
338
+ "name": "stackexchange_title_body/scifi.stackexchange.com.jsonl.gz",
339
+ "lines": 61528,
340
+ "weight": 3
341
+ },
342
+ {
343
+ "name": "stackexchange_title_body/mathematica.stackexchange.com.jsonl.gz",
344
+ "lines": 73131,
345
+ "weight": 4
346
+ },
347
+ {
348
+ "name": "TriviaQA_pairs.jsonl.gz",
349
+ "lines": 73346,
350
+ "weight": 4
351
+ },
352
+ {
353
+ "name": "stackexchange_title_body/drupal.stackexchange.com.jsonl.gz",
354
+ "lines": 79717,
355
+ "weight": 4
356
+ },
357
+ {
358
+ "name": "stackexchange_title_body/blender.stackexchange.com.jsonl.gz",
359
+ "lines": 80766,
360
+ "weight": 4
361
+ },
362
+ {
363
+ "name": "stackexchange_title_body/dba.stackexchange.com.jsonl.gz",
364
+ "lines": 81871,
365
+ "weight": 4
366
+ },
367
+ {
368
+ "name": "stackexchange_title_body/ell.stackexchange.com.jsonl.gz",
369
+ "lines": 83271,
370
+ "weight": 4
371
+ },
372
+ {
373
+ "name": "stackexchange_title_body/meta.stackexchange.com.jsonl.gz",
374
+ "lines": 83510,
375
+ "weight": 4
376
+ },
377
+ {
378
+ "name": "squad_pairs.jsonl.gz",
379
+ "lines": 87599,
380
+ "weight": 5
381
+ },
382
+ {
383
+ "name": "stackexchange_title_body/gaming.stackexchange.com.jsonl.gz",
384
+ "lines": 88912,
385
+ "weight": 5
386
+ },
387
+ {
388
+ "name": "stackexchange_title_body/sharepoint.stackexchange.com.jsonl.gz",
389
+ "lines": 94011,
390
+ "weight": 5
391
+ },
392
+ {
393
+ "name": "stackexchange_title_body/magento.stackexchange.com.jsonl.gz",
394
+ "lines": 99991,
395
+ "weight": 5
396
+ },
397
+ {
398
+ "name": "NQ-train_pairs.jsonl.gz",
399
+ "lines": 100231,
400
+ "weight": 5
401
+ },
402
+ {
403
+ "name": "stackexchange_title_body/wordpress.stackexchange.com.jsonl.gz",
404
+ "lines": 100474,
405
+ "weight": 5
406
+ },
407
+ {
408
+ "name": "SimpleWiki.jsonl.gz",
409
+ "lines": 102225,
410
+ "weight": 5
411
+ },
412
+ {
413
+ "name": "quora_duplicates_triplets.jsonl.gz",
414
+ "lines": 103663,
415
+ "weight": 5
416
+ },
417
+ {
418
+ "name": "stackexchange_title_body/salesforce.stackexchange.com.jsonl.gz",
419
+ "lines": 105260,
420
+ "weight": 5
421
+ },
422
+ {
423
+ "name": "stackexchange_title_body/english.stackexchange.com.jsonl.gz",
424
+ "lines": 109522,
425
+ "weight": 6
426
+ },
427
+ {
428
+ "name": "stackexchange_title_body/apple.stackexchange.com.jsonl.gz",
429
+ "lines": 110622,
430
+ "weight": 6
431
+ },
432
+ {
433
+ "name": "altlex.jsonl.gz",
434
+ "lines": 112696,
435
+ "weight": 6
436
+ },
437
+ {
438
+ "name": "stackexchange_title_body/mathoverflow.net.jsonl.gz",
439
+ "lines": 120851,
440
+ "weight": 6
441
+ },
442
+ {
443
+ "name": "wikihow.jsonl.gz",
444
+ "lines": 128542,
445
+ "weight": 6
446
+ },
447
+ {
448
+ "name": "stackexchange_title_body/gis.stackexchange.com.jsonl.gz",
449
+ "lines": 131000,
450
+ "weight": 7
451
+ },
452
+ {
453
+ "name": "stackexchange_title_body/electronics.stackexchange.com.jsonl.gz",
454
+ "lines": 143582,
455
+ "weight": 7
456
+ },
457
+ {
458
+ "name": "stackexchange_title_body/physics.stackexchange.com.jsonl.gz",
459
+ "lines": 173307,
460
+ "weight": 9
461
+ },
462
+ {
463
+ "name": "stackexchange_title_body/stats.stackexchange.com.jsonl.gz",
464
+ "lines": 173466,
465
+ "weight": 9
466
+ },
467
+ {
468
+ "name": "sentence-compression.jsonl.gz",
469
+ "lines": 180000,
470
+ "weight": 9
471
+ },
472
+ {
473
+ "name": "stackexchange_title_body/unix.stackexchange.com.jsonl.gz",
474
+ "lines": 185997,
475
+ "weight": 9
476
+ },
477
+ {
478
+ "name": "stackexchange_title_body/tex.stackexchange.com.jsonl.gz",
479
+ "lines": 202954,
480
+ "weight": 10
481
+ },
482
+ {
483
+ "name": "stackexchange_duplicate_questions_title-body_title-body.jsonl.gz",
484
+ "lines": 250460,
485
+ "weight": 12
486
+ },
487
+ {
488
+ "name": "stackexchange_duplicate_questions_body_body.jsonl.gz",
489
+ "lines": 250519,
490
+ "weight": 12
491
+ },
492
+ {
493
+ "name": "stackexchange_title_body/serverfault.com.jsonl.gz",
494
+ "lines": 270904,
495
+ "weight": 13
496
+ },
497
+ {
498
+ "name": "AllNLI.jsonl.gz",
499
+ "lines": 277230,
500
+ "weight": 13
501
+ },
502
+ {
503
+ "name": "stackexchange_duplicate_questions_title_title.jsonl.gz",
504
+ "lines": 304525,
505
+ "weight": 15
506
+ },
507
+ {
508
+ "name": "eli5_question_answer.jsonl.gz",
509
+ "lines": 325475,
510
+ "weight": 16
511
+ },
512
+ {
513
+ "name": "specter_train_triples.jsonl.gz",
514
+ "lines": 684100,
515
+ "weight": 16
516
+ },
517
+ {
518
+ "name": "stackexchange_title_body/askubuntu.com.jsonl.gz",
519
+ "lines": 347925,
520
+ "weight": 17
521
+ },
522
+ {
523
+ "name": "stackexchange_title_body/superuser.com.jsonl.gz",
524
+ "lines": 435463,
525
+ "weight": 21
526
+ },
527
+ {
528
+ "name": "stackexchange_title_body/small_stackexchanges.jsonl.gz",
529
+ "lines": 448146,
530
+ "weight": 21
531
+ },
532
+ {
533
+ "name": "S2ORC_title_abstract.jsonl.gz",
534
+ "lines": 41769185,
535
+ "weight": 23
536
+ },
537
+ {
538
+ "name": "S2ORC_citation_pairs.jsonl.gz",
539
+ "lines": 52603982,
540
+ "weight": 12
541
+ },
542
+ {
543
+ "name": "S2ORC_citation_pairs_abstract.jsonl.gz",
544
+ "lines": 116288806,
545
+ "weight": 12
546
+ },
547
+ {
548
+ "name": "PAQ_pairs.jsonl.gz",
549
+ "lines": 64371441,
550
+ "weight": 23
551
+ },
552
+ {
553
+ "name": "WikiAnswers_pairs.jsonl.gz",
554
+ "lines": 77427422,
555
+ "weight": 23
556
+ },
557
+ {
558
+ "name": "searchQA_question_top5_snippets_merged.jsonl.gz",
559
+ "lines": 582261,
560
+ "weight": 28
561
+ },
562
+ {
563
+ "name": "yahoo_answers_title_question.jsonl.gz",
564
+ "lines": 659896,
565
+ "weight": 31
566
+ },
567
+ {
568
+ "name": "yahoo_answers_question_answer.jsonl.gz",
569
+ "lines": 681164,
570
+ "weight": 32
571
+ },
572
+ {
573
+ "name": "yahoo_answers_title_answer.jsonl.gz",
574
+ "lines": 1198260,
575
+ "weight": 47
576
+ },
577
+ {
578
+ "name": "stackexchange_title_body/math.stackexchange.com.jsonl.gz",
579
+ "lines": 1338443,
580
+ "weight": 47
581
+ },
582
+ {
583
+ "name": "gooaq_pairs.jsonl.gz",
584
+ "lines": 3012496,
585
+ "weight": 47
586
+ },
587
+ {
588
+ "name": "msmarco-query_passage_negative.jsonl.gz",
589
+ "lines": 9144553,
590
+ "weight": 47
591
+ },
592
+ {
593
+ "name": "stackexchange_title_body/stackoverflow.com-Posts.jsonl.gz",
594
+ "lines": 18562443,
595
+ "weight": 47
596
+ },
597
+ {"name": "reddit/reddit_2015.jsonl.gz", "weight": 50},
598
+ {"name": "reddit/reddit_2016.jsonl.gz", "weight": 50},
599
+ {"name": "reddit/reddit_2017.jsonl.gz", "weight": 50},
600
+ {"name": "reddit/reddit_2018.jsonl.gz", "weight": 50}
601
+ ]
merges.txt ADDED
The diff for this file is too large to render. See raw diff
modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29bb8f3e407eaaa38e2675111fa56cf6f56cb56aab4f2257477fbb1467c07747
3
+ size 1421566897
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ {
2
+ "max_seq_length": 128,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "roberta-large", "tokenizer_class": "RobertaTokenizer"}
train_script.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train script for a single file
3
+
4
+ Need to set the TPU address first:
5
+ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6
+ """
7
+
8
+ import torch.multiprocessing as mp
9
+ import threading
10
+ import time
11
+ import random
12
+ import sys
13
+ import argparse
14
+ import gzip
15
+ import json
16
+ import logging
17
+ import tqdm
18
+ import torch
19
+ from torch import nn
20
+ from torch.utils.data import DataLoader
21
+ import torch
22
+ import torch_xla
23
+ import torch_xla.core
24
+ import torch_xla.core.functions
25
+ import torch_xla.core.xla_model as xm
26
+ import torch_xla.distributed.xla_multiprocessing as xmp
27
+ import torch_xla.distributed.parallel_loader as pl
28
+ import os
29
+ from shutil import copyfile
30
+
31
+
32
+ from transformers import (
33
+ AdamW,
34
+ AutoModel,
35
+ AutoTokenizer,
36
+ get_linear_schedule_with_warmup,
37
+ set_seed,
38
+ )
39
+
40
+ class AutoModelForSentenceEmbedding(nn.Module):
41
+ def __init__(self, model_name, tokenizer, normalize=True):
42
+ super(AutoModelForSentenceEmbedding, self).__init__()
43
+
44
+ self.model = AutoModel.from_pretrained(model_name)
45
+ self.normalize = normalize
46
+ self.tokenizer = tokenizer
47
+
48
+ def forward(self, **kwargs):
49
+ model_output = self.model(**kwargs)
50
+ embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
51
+ if self.normalize:
52
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
53
+
54
+ return embeddings
55
+
56
+ def mean_pooling(self, model_output, attention_mask):
57
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
58
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
59
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
60
+
61
+ def save_pretrained(self, output_path):
62
+ if xm.is_master_ordinal():
63
+ self.tokenizer.save_pretrained(output_path)
64
+ self.model.config.save_pretrained(output_path)
65
+
66
+ xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
67
+
68
+
69
+
70
+
71
+ def train_function(index, args, queue):
72
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
73
+ model = AutoModelForSentenceEmbedding(args.model, tokenizer)
74
+
75
+
76
+ ### Train Loop
77
+ device = xm.xla_device()
78
+ model = model.to(device)
79
+
80
+ # Instantiate optimizer
81
+ optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
82
+
83
+ lr_scheduler = get_linear_schedule_with_warmup(
84
+ optimizer=optimizer,
85
+ num_warmup_steps=500,
86
+ num_training_steps=args.steps,
87
+ )
88
+
89
+ # Now we train the model
90
+ cross_entropy_loss = nn.CrossEntropyLoss()
91
+ max_grad_norm = 1
92
+
93
+ model.train()
94
+
95
+ for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
96
+ #### Get the batch data
97
+ batch = queue.get()
98
+ #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))
99
+
100
+
101
+ if len(batch[0]) == 2: #(anchor, positive)
102
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
103
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
104
+
105
+ ### Compute embeddings
106
+ embeddings_a = model(**text1.to(device))
107
+ embeddings_b = model(**text2.to(device))
108
+
109
+ ### Gather all embedings
110
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
111
+ embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
112
+
113
+ ### Compute similarity scores 512 x 512
114
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
115
+
116
+ ### Compute cross-entropy loss
117
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
118
+
119
+ ## Symmetric loss as in CLIP
120
+ loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
121
+
122
+ else: #(anchor, positive, negative)
123
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
124
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
125
+ text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
126
+
127
+ embeddings_a = model(**text1.to(device))
128
+ embeddings_b1 = model(**text2.to(device))
129
+ embeddings_b2 = model(**text3.to(device))
130
+
131
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
132
+ embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
133
+ embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)
134
+
135
+ embeddings_b = torch.cat([embeddings_b1, embeddings_b2])
136
+
137
+ ### Compute similarity scores 512 x 1024
138
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
139
+
140
+ ### Compute cross-entropy loss
141
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
142
+
143
+ ## One-way loss
144
+ loss = cross_entropy_loss(scores, labels)
145
+
146
+
147
+ # Backward pass
148
+ optimizer.zero_grad()
149
+ loss.backward()
150
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
151
+
152
+ xm.optimizer_step(optimizer, barrier=True)
153
+ lr_scheduler.step()
154
+
155
+
156
+ #Save model
157
+ if (global_step+1) % args.save_steps == 0:
158
+ output_path = os.path.join(args.output, str(global_step+1))
159
+ xm.master_print("save model: "+output_path)
160
+ model.save_pretrained(output_path)
161
+
162
+
163
+ output_path = os.path.join(args.output, "final")
164
+ xm.master_print("save model final: "+ output_path)
165
+ model.save_pretrained(output_path)
166
+
167
+
168
+ def produce_data(args, queue, filepaths, dataset_indices):
169
+ global_batch_size = args.batch_size*args.nprocs #Global batch size
170
+ size_per_dataset = int(global_batch_size / args.datasets_per_batch) #How many datasets per batch
171
+ num_same_dataset = int(size_per_dataset / args.batch_size)
172
+ print("producer", "global_batch_size", global_batch_size)
173
+ print("producer", "size_per_dataset", size_per_dataset)
174
+ print("producer", "num_same_dataset", num_same_dataset)
175
+
176
+ datasets = []
177
+ for filepath in filepaths:
178
+ if "reddit_" in filepath: #Special dataset class for Reddit files
179
+ data_obj = RedditDataset(filepath)
180
+ else:
181
+ data_obj = Dataset(filepath)
182
+ datasets.append(iter(data_obj))
183
+
184
+ # Store if dataset is in a 2 col or 3 col format
185
+ num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}
186
+
187
+ while True:
188
+ texts_in_batch = set()
189
+ batch_format = None #2 vs 3 col format for this batch
190
+
191
+ #Add data from several sub datasets
192
+ for _ in range(args.datasets_per_batch):
193
+ valid_dataset = False #Check that datasets have the same 2/3 col format
194
+ while not valid_dataset:
195
+ data_idx = random.choice(dataset_indices)
196
+ if batch_format is None:
197
+ batch_format = num_cols[data_idx]
198
+ valid_dataset = True
199
+ else: #Check that this dataset has the same format
200
+ valid_dataset = (batch_format == num_cols[data_idx])
201
+
202
+ #Get data from this dataset
203
+ dataset = datasets[data_idx]
204
+ for _ in range(num_same_dataset):
205
+ for _ in range(args.nprocs):
206
+ batch_device = [] #A batch for one device
207
+ while len(batch_device) < args.batch_size:
208
+ sample = next(dataset)
209
+ in_batch = False
210
+ for text in sample:
211
+ if text in texts_in_batch:
212
+ in_batch = True
213
+ break
214
+
215
+ if not in_batch:
216
+ for text in sample:
217
+ texts_in_batch.add(text)
218
+ batch_device.append(sample)
219
+
220
+ queue.put(batch_device)
221
+
222
+
223
+ class RedditDataset:
224
+ """
225
+ A class that handles the reddit data files
226
+ """
227
+ def __init__(self, filepath):
228
+ self.filepath = filepath
229
+
230
+ def __iter__(self):
231
+ while True:
232
+ with gzip.open(self.filepath, "rt") as fIn:
233
+ for line in fIn:
234
+ data = json.loads(line)
235
+
236
+ if "response" in data and "context" in data:
237
+ yield [data["response"], data["context"]]
238
+
239
+ class Dataset:
240
+ """
241
+ A class that handles one dataset
242
+ """
243
+ def __init__(self, filepath):
244
+ self.filepath = filepath
245
+
246
+ def __iter__(self):
247
+ max_dataset_size = 10*1000*1000 #Cache small datasets in memory
248
+ dataset = []
249
+ data_format = None
250
+
251
+ while dataset is None or len(dataset) == 0:
252
+ with gzip.open(self.filepath, "rt") as fIn:
253
+ for line in fIn:
254
+ data = json.loads(line)
255
+ if isinstance(data, dict):
256
+ data = data['texts']
257
+
258
+ if data_format is None:
259
+ data_format = len(data)
260
+
261
+ #Ensure that all entries are of the same 2/3 col format
262
+ assert len(data) == data_format
263
+
264
+ if dataset is not None:
265
+ dataset.append(data)
266
+ if len(dataset) >= max_dataset_size:
267
+ dataset = None
268
+
269
+ yield data
270
+
271
+ # Data loaded. Now stream to the queue
272
+ # Shuffle for each epoch
273
+ while True:
274
+ random.shuffle(dataset)
275
+ for data in dataset:
276
+ yield data
277
+
278
+
279
+
280
+ if __name__ == "__main__":
281
+ parser = argparse.ArgumentParser()
282
+ parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
283
+ parser.add_argument('--steps', type=int, default=2000)
284
+ parser.add_argument('--save_steps', type=int, default=10000)
285
+ parser.add_argument('--batch_size', type=int, default=64)
286
+ parser.add_argument('--max_length', type=int, default=128)
287
+ parser.add_argument('--nprocs', type=int, default=8)
288
+ parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
289
+ parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
290
+ parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
291
+ parser.add_argument('data_config', help="A data_config.json file")
292
+ parser.add_argument('output')
293
+ args = parser.parse_args()
294
+
295
+ # Ensure global batch size is divisble by data_sample_size
296
+ assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
297
+
298
+ logging.info("Output: "+args.output)
299
+ if os.path.exists(args.output):
300
+ print("Output folder already exists.")
301
+ input("Continue?")
302
+
303
+ # Write train script to output path
304
+ os.makedirs(args.output, exist_ok=True)
305
+
306
+ data_config_path = os.path.join(args.output, 'data_config.json')
307
+ copyfile(args.data_config, data_config_path)
308
+
309
+ train_script_path = os.path.join(args.output, 'train_script.py')
310
+ copyfile(__file__, train_script_path)
311
+ with open(train_script_path, 'a') as fOut:
312
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
313
+
314
+
315
+
316
+ #Load data config
317
+ with open(args.data_config) as fIn:
318
+ data_config = json.load(fIn)
319
+
320
+ queue = mp.Queue(maxsize=100*args.nprocs)
321
+
322
+ filepaths = []
323
+ dataset_indices = []
324
+ for idx, data in enumerate(data_config):
325
+ filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
326
+ dataset_indices.extend([idx]*data['weight'])
327
+
328
+ # Start producer
329
+ p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
330
+ p.start()
331
+
332
+ # Run training
333
+ print("Start processes:", args.nprocs)
334
+ xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
335
+ print("Training done")
336
+ print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
337
+ print("With 'pkill python' you can kill all remaining python processes")
338
+ p.kill()
339
+ exit()
340
+
341
+
342
+
343
+ # Script was called via:
344
+ #python train_many_data_files_v2.py --steps 1000000 --batch_size 32 --model roberta-large train_data_configs/all_datasets_v3.json output/all_datasets_v3_roberta-large
vocab.json ADDED
The diff for this file is too large to render. See raw diff