nreimers commited on
Commit
bf953af
·
1 Parent(s): d8dd501
1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: sentence-similarity
3
+ tags:
4
+ - sentence-transformers
5
+ - feature-extraction
6
+ - sentence-similarity
7
+ ---
8
+
9
+ # all_datasets_v3_mpnet-base
10
+
11
+ This is a mpnet-base model trained on all the dataset of the 1B+ train corpus. It was trained with the v3 setup. See data_config.json and train_script.py in this respository how the model was trained and which datasets have been used.
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "output/all_datasets_v3_mpnet-base/120000",
3
+ "architectures": [
4
+ "MPNetForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "eos_token_id": 2,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "mpnet",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "relative_attention_num_buckets": 32,
21
+ "transformers_version": "4.8.2",
22
+ "vocab_size": 30527
23
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.0.0",
4
+ "transformers": "4.6.1",
5
+ "pytorch": "1.8.1"
6
+ }
7
+ }
data_config.json ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "stackexchange_title_body/skeptics.stackexchange.com.jsonl.gz",
4
+ "lines": 10009,
5
+ "weight": 1
6
+ },
7
+ {
8
+ "name": "stackexchange_title_body/writers.stackexchange.com.jsonl.gz",
9
+ "lines": 10157,
10
+ "weight": 1
11
+ },
12
+ {
13
+ "name": "stackexchange_title_body/astronomy.stackexchange.com.jsonl.gz",
14
+ "lines": 10462,
15
+ "weight": 1
16
+ },
17
+ {
18
+ "name": "stackexchange_title_body/vi.stackexchange.com.jsonl.gz",
19
+ "lines": 10551,
20
+ "weight": 1
21
+ },
22
+ {
23
+ "name": "stackexchange_title_body/cstheory.stackexchange.com.jsonl.gz",
24
+ "lines": 10642,
25
+ "weight": 1
26
+ },
27
+ {
28
+ "name": "stackexchange_title_body/engineering.stackexchange.com.jsonl.gz",
29
+ "lines": 10753,
30
+ "weight": 1
31
+ },
32
+ {
33
+ "name": "stackexchange_title_body/french.stackexchange.com.jsonl.gz",
34
+ "lines": 10794,
35
+ "weight": 1
36
+ },
37
+ {
38
+ "name": "stackexchange_title_body/economics.stackexchange.com.jsonl.gz",
39
+ "lines": 11115,
40
+ "weight": 1
41
+ },
42
+ {
43
+ "name": "stackexchange_title_body/anime.stackexchange.com.jsonl.gz",
44
+ "lines": 11444,
45
+ "weight": 1
46
+ },
47
+ {
48
+ "name": "stackexchange_title_body/islam.stackexchange.com.jsonl.gz",
49
+ "lines": 11853,
50
+ "weight": 1
51
+ },
52
+ {
53
+ "name": "stackexchange_title_body/expressionengine.stackexchange.com.jsonl.gz",
54
+ "lines": 11866,
55
+ "weight": 1
56
+ },
57
+ {
58
+ "name": "stackexchange_title_body/politics.stackexchange.com.jsonl.gz",
59
+ "lines": 11894,
60
+ "weight": 1
61
+ },
62
+ {
63
+ "name": "stackexchange_title_body/history.stackexchange.com.jsonl.gz",
64
+ "lines": 12021,
65
+ "weight": 1
66
+ },
67
+ {
68
+ "name": "stackexchange_title_body/christianity.stackexchange.com.jsonl.gz",
69
+ "lines": 12108,
70
+ "weight": 1
71
+ },
72
+ {
73
+ "name": "stackexchange_title_body/boardgames.stackexchange.com.jsonl.gz",
74
+ "lines": 12149,
75
+ "weight": 1
76
+ },
77
+ {
78
+ "name": "stackexchange_title_body/civicrm.stackexchange.com.jsonl.gz",
79
+ "lines": 12543,
80
+ "weight": 1
81
+ },
82
+ {
83
+ "name": "stackexchange_title_body/craftcms.stackexchange.com.jsonl.gz",
84
+ "lines": 12574,
85
+ "weight": 1
86
+ },
87
+ {
88
+ "name": "stackexchange_title_body/hinduism.stackexchange.com.jsonl.gz",
89
+ "lines": 13450,
90
+ "weight": 1
91
+ },
92
+ {
93
+ "name": "stackexchange_title_body/networkengineering.stackexchange.com.jsonl.gz",
94
+ "lines": 13454,
95
+ "weight": 1
96
+ },
97
+ {
98
+ "name": "stackexchange_title_body/german.stackexchange.com.jsonl.gz",
99
+ "lines": 13950,
100
+ "weight": 1
101
+ },
102
+ {
103
+ "name": "stackexchange_title_body/philosophy.stackexchange.com.jsonl.gz",
104
+ "lines": 14829,
105
+ "weight": 1
106
+ },
107
+ {
108
+ "name": "stackexchange_title_body/gardening.stackexchange.com.jsonl.gz",
109
+ "lines": 15136,
110
+ "weight": 1
111
+ },
112
+ {
113
+ "name": "stackexchange_title_body/space.stackexchange.com.jsonl.gz",
114
+ "lines": 15142,
115
+ "weight": 1
116
+ },
117
+ {
118
+ "name": "stackexchange_title_body/bicycles.stackexchange.com.jsonl.gz",
119
+ "lines": 16353,
120
+ "weight": 1
121
+ },
122
+ {
123
+ "name": "stackexchange_title_body/quant.stackexchange.com.jsonl.gz",
124
+ "lines": 17261,
125
+ "weight": 1
126
+ },
127
+ {
128
+ "name": "stackexchange_title_body/puzzling.stackexchange.com.jsonl.gz",
129
+ "lines": 17851,
130
+ "weight": 1
131
+ },
132
+ {
133
+ "name": "stackexchange_title_body/law.stackexchange.com.jsonl.gz",
134
+ "lines": 17941,
135
+ "weight": 1
136
+ },
137
+ {
138
+ "name": "stackexchange_title_body/arduino.stackexchange.com.jsonl.gz",
139
+ "lines": 19553,
140
+ "weight": 1
141
+ },
142
+ {
143
+ "name": "stackexchange_title_body/aviation.stackexchange.com.jsonl.gz",
144
+ "lines": 20139,
145
+ "weight": 1
146
+ },
147
+ {
148
+ "name": "stackexchange_title_body/softwarerecs.stackexchange.com.jsonl.gz",
149
+ "lines": 20142,
150
+ "weight": 1
151
+ },
152
+ {
153
+ "name": "stackexchange_title_body/movies.stackexchange.com.jsonl.gz",
154
+ "lines": 20181,
155
+ "weight": 1
156
+ },
157
+ {
158
+ "name": "stackexchange_title_body/music.stackexchange.com.jsonl.gz",
159
+ "lines": 20636,
160
+ "weight": 1
161
+ },
162
+ {
163
+ "name": "stackexchange_title_body/emacs.stackexchange.com.jsonl.gz",
164
+ "lines": 21055,
165
+ "weight": 1
166
+ },
167
+ {
168
+ "name": "stackexchange_title_body/dsp.stackexchange.com.jsonl.gz",
169
+ "lines": 21252,
170
+ "weight": 1
171
+ },
172
+ {
173
+ "name": "flickr30k_captions.jsonl.gz",
174
+ "lines": 317695,
175
+ "weight": 1
176
+ },
177
+ {
178
+ "name": "coco_captions.jsonl.gz",
179
+ "lines": 828395,
180
+ "weight": 1
181
+ },
182
+ {
183
+ "name": "codesearchnet.jsonl.gz",
184
+ "lines": 1151414,
185
+ "weight": 1
186
+ },
187
+ {
188
+ "name": "stackexchange_title_body/japanese.stackexchange.com.jsonl.gz",
189
+ "lines": 22056,
190
+ "weight": 2
191
+ },
192
+ {
193
+ "name": "stackexchange_title_body/mechanics.stackexchange.com.jsonl.gz",
194
+ "lines": 22868,
195
+ "weight": 2
196
+ },
197
+ {
198
+ "name": "stackexchange_title_body/crypto.stackexchange.com.jsonl.gz",
199
+ "lines": 23231,
200
+ "weight": 2
201
+ },
202
+ {
203
+ "name": "stackexchange_title_body/cooking.stackexchange.com.jsonl.gz",
204
+ "lines": 23705,
205
+ "weight": 2
206
+ },
207
+ {
208
+ "name": "stackexchange_title_body/photo.stackexchange.com.jsonl.gz",
209
+ "lines": 23753,
210
+ "weight": 2
211
+ },
212
+ {
213
+ "name": "stackexchange_title_body/workplace.stackexchange.com.jsonl.gz",
214
+ "lines": 24189,
215
+ "weight": 2
216
+ },
217
+ {
218
+ "name": "stackexchange_title_body/biology.stackexchange.com.jsonl.gz",
219
+ "lines": 24447,
220
+ "weight": 2
221
+ },
222
+ {
223
+ "name": "stackexchange_title_body/bitcoin.stackexchange.com.jsonl.gz",
224
+ "lines": 25374,
225
+ "weight": 2
226
+ },
227
+ {
228
+ "name": "stackexchange_title_body/worldbuilding.stackexchange.com.jsonl.gz",
229
+ "lines": 26763,
230
+ "weight": 2
231
+ },
232
+ {
233
+ "name": "stackexchange_title_body/datascience.stackexchange.com.jsonl.gz",
234
+ "lines": 27397,
235
+ "weight": 2
236
+ },
237
+ {
238
+ "name": "stackexchange_title_body/ux.stackexchange.com.jsonl.gz",
239
+ "lines": 29403,
240
+ "weight": 2
241
+ },
242
+ {
243
+ "name": "stackexchange_title_body/webapps.stackexchange.com.jsonl.gz",
244
+ "lines": 29697,
245
+ "weight": 2
246
+ },
247
+ {
248
+ "name": "stackexchange_title_body/graphicdesign.stackexchange.com.jsonl.gz",
249
+ "lines": 30233,
250
+ "weight": 2
251
+ },
252
+ {
253
+ "name": "stackexchange_title_body/raspberrypi.stackexchange.com.jsonl.gz",
254
+ "lines": 30625,
255
+ "weight": 2
256
+ },
257
+ {
258
+ "name": "stackexchange_title_body/money.stackexchange.com.jsonl.gz",
259
+ "lines": 32021,
260
+ "weight": 2
261
+ },
262
+ {
263
+ "name": "stackexchange_title_body/judaism.stackexchange.com.jsonl.gz",
264
+ "lines": 32028,
265
+ "weight": 2
266
+ },
267
+ {
268
+ "name": "stackexchange_title_body/ethereum.stackexchange.com.jsonl.gz",
269
+ "lines": 32760,
270
+ "weight": 2
271
+ },
272
+ {
273
+ "name": "stackexchange_title_body/academia.stackexchange.com.jsonl.gz",
274
+ "lines": 34331,
275
+ "weight": 2
276
+ },
277
+ {
278
+ "name": "stackexchange_title_body/chemistry.stackexchange.com.jsonl.gz",
279
+ "lines": 34506,
280
+ "weight": 2
281
+ },
282
+ {
283
+ "name": "stackexchange_title_body/webmasters.stackexchange.com.jsonl.gz",
284
+ "lines": 34559,
285
+ "weight": 2
286
+ },
287
+ {
288
+ "name": "stackexchange_title_body/meta.stackoverflow.com.jsonl.gz",
289
+ "lines": 36456,
290
+ "weight": 2
291
+ },
292
+ {
293
+ "name": "stackexchange_title_body/cs.stackexchange.com.jsonl.gz",
294
+ "lines": 38314,
295
+ "weight": 2
296
+ },
297
+ {
298
+ "name": "stackexchange_title_body/travel.stackexchange.com.jsonl.gz",
299
+ "lines": 41227,
300
+ "weight": 2
301
+ },
302
+ {
303
+ "name": "stackexchange_title_body/rpg.stackexchange.com.jsonl.gz",
304
+ "lines": 42303,
305
+ "weight": 2
306
+ },
307
+ {
308
+ "name": "stackexchange_title_body/codereview.stackexchange.com.jsonl.gz",
309
+ "lines": 45765,
310
+ "weight": 3
311
+ },
312
+ {
313
+ "name": "stackexchange_title_body/gamedev.stackexchange.com.jsonl.gz",
314
+ "lines": 46485,
315
+ "weight": 3
316
+ },
317
+ {
318
+ "name": "stackexchange_title_body/android.stackexchange.com.jsonl.gz",
319
+ "lines": 51608,
320
+ "weight": 3
321
+ },
322
+ {
323
+ "name": "stackexchange_title_body/softwareengineering.stackexchange.com.jsonl.gz",
324
+ "lines": 53942,
325
+ "weight": 3
326
+ },
327
+ {
328
+ "name": "stackexchange_title_body/security.stackexchange.com.jsonl.gz",
329
+ "lines": 58000,
330
+ "weight": 3
331
+ },
332
+ {
333
+ "name": "stackexchange_title_body/diy.stackexchange.com.jsonl.gz",
334
+ "lines": 60083,
335
+ "weight": 3
336
+ },
337
+ {
338
+ "name": "stackexchange_title_body/scifi.stackexchange.com.jsonl.gz",
339
+ "lines": 61528,
340
+ "weight": 3
341
+ },
342
+ {
343
+ "name": "stackexchange_title_body/mathematica.stackexchange.com.jsonl.gz",
344
+ "lines": 73131,
345
+ "weight": 4
346
+ },
347
+ {
348
+ "name": "TriviaQA_pairs.jsonl.gz",
349
+ "lines": 73346,
350
+ "weight": 4
351
+ },
352
+ {
353
+ "name": "stackexchange_title_body/drupal.stackexchange.com.jsonl.gz",
354
+ "lines": 79717,
355
+ "weight": 4
356
+ },
357
+ {
358
+ "name": "stackexchange_title_body/blender.stackexchange.com.jsonl.gz",
359
+ "lines": 80766,
360
+ "weight": 4
361
+ },
362
+ {
363
+ "name": "stackexchange_title_body/dba.stackexchange.com.jsonl.gz",
364
+ "lines": 81871,
365
+ "weight": 4
366
+ },
367
+ {
368
+ "name": "stackexchange_title_body/ell.stackexchange.com.jsonl.gz",
369
+ "lines": 83271,
370
+ "weight": 4
371
+ },
372
+ {
373
+ "name": "stackexchange_title_body/meta.stackexchange.com.jsonl.gz",
374
+ "lines": 83510,
375
+ "weight": 4
376
+ },
377
+ {
378
+ "name": "squad_pairs.jsonl.gz",
379
+ "lines": 87599,
380
+ "weight": 5
381
+ },
382
+ {
383
+ "name": "stackexchange_title_body/gaming.stackexchange.com.jsonl.gz",
384
+ "lines": 88912,
385
+ "weight": 5
386
+ },
387
+ {
388
+ "name": "stackexchange_title_body/sharepoint.stackexchange.com.jsonl.gz",
389
+ "lines": 94011,
390
+ "weight": 5
391
+ },
392
+ {
393
+ "name": "stackexchange_title_body/magento.stackexchange.com.jsonl.gz",
394
+ "lines": 99991,
395
+ "weight": 5
396
+ },
397
+ {
398
+ "name": "NQ-train_pairs.jsonl.gz",
399
+ "lines": 100231,
400
+ "weight": 5
401
+ },
402
+ {
403
+ "name": "stackexchange_title_body/wordpress.stackexchange.com.jsonl.gz",
404
+ "lines": 100474,
405
+ "weight": 5
406
+ },
407
+ {
408
+ "name": "SimpleWiki.jsonl.gz",
409
+ "lines": 102225,
410
+ "weight": 5
411
+ },
412
+ {
413
+ "name": "quora_duplicates_triplets.jsonl.gz",
414
+ "lines": 103663,
415
+ "weight": 5
416
+ },
417
+ {
418
+ "name": "stackexchange_title_body/salesforce.stackexchange.com.jsonl.gz",
419
+ "lines": 105260,
420
+ "weight": 5
421
+ },
422
+ {
423
+ "name": "stackexchange_title_body/english.stackexchange.com.jsonl.gz",
424
+ "lines": 109522,
425
+ "weight": 6
426
+ },
427
+ {
428
+ "name": "stackexchange_title_body/apple.stackexchange.com.jsonl.gz",
429
+ "lines": 110622,
430
+ "weight": 6
431
+ },
432
+ {
433
+ "name": "altlex.jsonl.gz",
434
+ "lines": 112696,
435
+ "weight": 6
436
+ },
437
+ {
438
+ "name": "stackexchange_title_body/mathoverflow.net.jsonl.gz",
439
+ "lines": 120851,
440
+ "weight": 6
441
+ },
442
+ {
443
+ "name": "wikihow.jsonl.gz",
444
+ "lines": 128542,
445
+ "weight": 6
446
+ },
447
+ {
448
+ "name": "stackexchange_title_body/gis.stackexchange.com.jsonl.gz",
449
+ "lines": 131000,
450
+ "weight": 7
451
+ },
452
+ {
453
+ "name": "stackexchange_title_body/electronics.stackexchange.com.jsonl.gz",
454
+ "lines": 143582,
455
+ "weight": 7
456
+ },
457
+ {
458
+ "name": "stackexchange_title_body/physics.stackexchange.com.jsonl.gz",
459
+ "lines": 173307,
460
+ "weight": 9
461
+ },
462
+ {
463
+ "name": "stackexchange_title_body/stats.stackexchange.com.jsonl.gz",
464
+ "lines": 173466,
465
+ "weight": 9
466
+ },
467
+ {
468
+ "name": "sentence-compression.jsonl.gz",
469
+ "lines": 180000,
470
+ "weight": 9
471
+ },
472
+ {
473
+ "name": "stackexchange_title_body/unix.stackexchange.com.jsonl.gz",
474
+ "lines": 185997,
475
+ "weight": 9
476
+ },
477
+ {
478
+ "name": "stackexchange_title_body/tex.stackexchange.com.jsonl.gz",
479
+ "lines": 202954,
480
+ "weight": 10
481
+ },
482
+ {
483
+ "name": "stackexchange_duplicate_questions_title-body_title-body.jsonl.gz",
484
+ "lines": 250460,
485
+ "weight": 12
486
+ },
487
+ {
488
+ "name": "stackexchange_duplicate_questions_body_body.jsonl.gz",
489
+ "lines": 250519,
490
+ "weight": 12
491
+ },
492
+ {
493
+ "name": "stackexchange_title_body/serverfault.com.jsonl.gz",
494
+ "lines": 270904,
495
+ "weight": 13
496
+ },
497
+ {
498
+ "name": "AllNLI.jsonl.gz",
499
+ "lines": 277230,
500
+ "weight": 13
501
+ },
502
+ {
503
+ "name": "stackexchange_duplicate_questions_title_title.jsonl.gz",
504
+ "lines": 304525,
505
+ "weight": 15
506
+ },
507
+ {
508
+ "name": "eli5_question_answer.jsonl.gz",
509
+ "lines": 325475,
510
+ "weight": 16
511
+ },
512
+ {
513
+ "name": "specter_train_triples.jsonl.gz",
514
+ "lines": 684100,
515
+ "weight": 16
516
+ },
517
+ {
518
+ "name": "stackexchange_title_body/askubuntu.com.jsonl.gz",
519
+ "lines": 347925,
520
+ "weight": 17
521
+ },
522
+ {
523
+ "name": "stackexchange_title_body/superuser.com.jsonl.gz",
524
+ "lines": 435463,
525
+ "weight": 21
526
+ },
527
+ {
528
+ "name": "stackexchange_title_body/small_stackexchanges.jsonl.gz",
529
+ "lines": 448146,
530
+ "weight": 21
531
+ },
532
+ {
533
+ "name": "S2ORC_title_abstract.jsonl.gz",
534
+ "lines": 41769185,
535
+ "weight": 23
536
+ },
537
+ {
538
+ "name": "S2ORC_citation_pairs.jsonl.gz",
539
+ "lines": 52603982,
540
+ "weight": 12
541
+ },
542
+ {
543
+ "name": "S2ORC_citation_pairs_abstract.jsonl.gz",
544
+ "lines": 116288806,
545
+ "weight": 12
546
+ },
547
+ {
548
+ "name": "PAQ_pairs.jsonl.gz",
549
+ "lines": 64371441,
550
+ "weight": 23
551
+ },
552
+ {
553
+ "name": "WikiAnswers_pairs.jsonl.gz",
554
+ "lines": 77427422,
555
+ "weight": 23
556
+ },
557
+ {
558
+ "name": "searchQA_question_top5_snippets_merged.jsonl.gz",
559
+ "lines": 582261,
560
+ "weight": 28
561
+ },
562
+ {
563
+ "name": "yahoo_answers_title_question.jsonl.gz",
564
+ "lines": 659896,
565
+ "weight": 31
566
+ },
567
+ {
568
+ "name": "yahoo_answers_question_answer.jsonl.gz",
569
+ "lines": 681164,
570
+ "weight": 32
571
+ },
572
+ {
573
+ "name": "yahoo_answers_title_answer.jsonl.gz",
574
+ "lines": 1198260,
575
+ "weight": 47
576
+ },
577
+ {
578
+ "name": "stackexchange_title_body/math.stackexchange.com.jsonl.gz",
579
+ "lines": 1338443,
580
+ "weight": 47
581
+ },
582
+ {
583
+ "name": "gooaq_pairs.jsonl.gz",
584
+ "lines": 3012496,
585
+ "weight": 47
586
+ },
587
+ {
588
+ "name": "msmarco-query_passage_negative.jsonl.gz",
589
+ "lines": 9144553,
590
+ "weight": 47
591
+ },
592
+ {
593
+ "name": "stackexchange_title_body/stackoverflow.com-Posts.jsonl.gz",
594
+ "lines": 18562443,
595
+ "weight": 47
596
+ },
597
+ {"name": "reddit/reddit_2015.jsonl.gz", "weight": 50},
598
+ {"name": "reddit/reddit_2016.jsonl.gz", "weight": 50},
599
+ {"name": "reddit/reddit_2017.jsonl.gz", "weight": 50},
600
+ {"name": "reddit/reddit_2018.jsonl.gz", "weight": 50}
601
+ ]
modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:879a1f3543bda9609f8ae74c68236cc5049769fcac6fbd68a70aafd6762dca01
3
+ size 438011953
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "[UNK]", "pad_token": "<pad>", "mask_token": "<mask>", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "output/all_datasets_v3_mpnet-base/120000", "tokenizer_class": "MPNetTokenizer"}
train_script.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train script for a single file
3
+
4
+ Need to set the TPU address first:
5
+ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6
+ """
7
+
8
+ import torch.multiprocessing as mp
9
+ import threading
10
+ import time
11
+ import random
12
+ import sys
13
+ import argparse
14
+ import gzip
15
+ import json
16
+ import logging
17
+ import tqdm
18
+ import torch
19
+ from torch import nn
20
+ from torch.utils.data import DataLoader
21
+ import torch
22
+ import torch_xla
23
+ import torch_xla.core
24
+ import torch_xla.core.functions
25
+ import torch_xla.core.xla_model as xm
26
+ import torch_xla.distributed.xla_multiprocessing as xmp
27
+ import torch_xla.distributed.parallel_loader as pl
28
+ import os
29
+ from shutil import copyfile
30
+
31
+
32
+ from transformers import (
33
+ AdamW,
34
+ AutoModel,
35
+ AutoTokenizer,
36
+ get_linear_schedule_with_warmup,
37
+ set_seed,
38
+ )
39
+
40
+ class AutoModelForSentenceEmbedding(nn.Module):
41
+ def __init__(self, model_name, tokenizer, normalize=True):
42
+ super(AutoModelForSentenceEmbedding, self).__init__()
43
+
44
+ self.model = AutoModel.from_pretrained(model_name)
45
+ self.normalize = normalize
46
+ self.tokenizer = tokenizer
47
+
48
+ def forward(self, **kwargs):
49
+ model_output = self.model(**kwargs)
50
+ embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
51
+ if self.normalize:
52
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
53
+
54
+ return embeddings
55
+
56
+ def mean_pooling(self, model_output, attention_mask):
57
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
58
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
59
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
60
+
61
+ def save_pretrained(self, output_path):
62
+ if xm.is_master_ordinal():
63
+ self.tokenizer.save_pretrained(output_path)
64
+ self.model.config.save_pretrained(output_path)
65
+
66
+ xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
67
+
68
+
69
+
70
+
71
+ def train_function(index, args, queue):
72
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
73
+ model = AutoModelForSentenceEmbedding(args.model, tokenizer)
74
+
75
+
76
+ ### Train Loop
77
+ device = xm.xla_device()
78
+ model = model.to(device)
79
+
80
+ # Instantiate optimizer
81
+ optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
82
+
83
+ lr_scheduler = get_linear_schedule_with_warmup(
84
+ optimizer=optimizer,
85
+ num_warmup_steps=500,
86
+ num_training_steps=args.steps,
87
+ )
88
+
89
+ # Now we train the model
90
+ cross_entropy_loss = nn.CrossEntropyLoss()
91
+ max_grad_norm = 1
92
+
93
+ model.train()
94
+
95
+ for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
96
+ #### Get the batch data
97
+ batch = queue.get()
98
+ #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))
99
+
100
+
101
+ if len(batch[0]) == 2: #(anchor, positive)
102
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
103
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
104
+
105
+ ### Compute embeddings
106
+ embeddings_a = model(**text1.to(device))
107
+ embeddings_b = model(**text2.to(device))
108
+
109
+ ### Gather all embedings
110
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
111
+ embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
112
+
113
+ ### Compute similarity scores 512 x 512
114
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
115
+
116
+ ### Compute cross-entropy loss
117
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
118
+
119
+ ## Symmetric loss as in CLIP
120
+ loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
121
+
122
+ else: #(anchor, positive, negative)
123
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
124
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
125
+ text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
126
+
127
+ embeddings_a = model(**text1.to(device))
128
+ embeddings_b1 = model(**text2.to(device))
129
+ embeddings_b2 = model(**text3.to(device))
130
+
131
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
132
+ embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
133
+ embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)
134
+
135
+ embeddings_b = torch.cat([embeddings_b1, embeddings_b2])
136
+
137
+ ### Compute similarity scores 512 x 1024
138
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
139
+
140
+ ### Compute cross-entropy loss
141
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
142
+
143
+ ## One-way loss
144
+ loss = cross_entropy_loss(scores, labels)
145
+
146
+
147
+ # Backward pass
148
+ optimizer.zero_grad()
149
+ loss.backward()
150
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
151
+
152
+ xm.optimizer_step(optimizer, barrier=True)
153
+ lr_scheduler.step()
154
+
155
+
156
+ #Save model
157
+ if (global_step+1) % args.save_steps == 0:
158
+ output_path = os.path.join(args.output, str(global_step+1))
159
+ xm.master_print("save model: "+output_path)
160
+ model.save_pretrained(output_path)
161
+
162
+
163
+ output_path = os.path.join(args.output, "final")
164
+ xm.master_print("save model final: "+ output_path)
165
+ model.save_pretrained(output_path)
166
+
167
+
168
+ def produce_data(args, queue, filepaths, dataset_indices):
169
+ global_batch_size = args.batch_size*args.nprocs #Global batch size
170
+ size_per_dataset = int(global_batch_size / args.datasets_per_batch) #How many datasets per batch
171
+ num_same_dataset = int(size_per_dataset / args.batch_size)
172
+ print("producer", "global_batch_size", global_batch_size)
173
+ print("producer", "size_per_dataset", size_per_dataset)
174
+ print("producer", "num_same_dataset", num_same_dataset)
175
+
176
+ datasets = []
177
+ for filepath in filepaths:
178
+ if "reddit_" in filepath: #Special dataset class for Reddit files
179
+ data_obj = RedditDataset(filepath)
180
+ else:
181
+ data_obj = Dataset(filepath)
182
+ datasets.append(iter(data_obj))
183
+
184
+ # Store if dataset is in a 2 col or 3 col format
185
+ num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}
186
+
187
+ while True:
188
+ texts_in_batch = set()
189
+ batch_format = None #2 vs 3 col format for this batch
190
+
191
+ #Add data from several sub datasets
192
+ for _ in range(args.datasets_per_batch):
193
+ valid_dataset = False #Check that datasets have the same 2/3 col format
194
+ while not valid_dataset:
195
+ data_idx = random.choice(dataset_indices)
196
+ if batch_format is None:
197
+ batch_format = num_cols[data_idx]
198
+ valid_dataset = True
199
+ else: #Check that this dataset has the same format
200
+ valid_dataset = (batch_format == num_cols[data_idx])
201
+
202
+ #Get data from this dataset
203
+ dataset = datasets[data_idx]
204
+ for _ in range(num_same_dataset):
205
+ for _ in range(args.nprocs):
206
+ batch_device = [] #A batch for one device
207
+ while len(batch_device) < args.batch_size:
208
+ sample = next(dataset)
209
+ in_batch = False
210
+ for text in sample:
211
+ if text in texts_in_batch:
212
+ in_batch = True
213
+ break
214
+
215
+ if not in_batch:
216
+ for text in sample:
217
+ texts_in_batch.add(text)
218
+ batch_device.append(sample)
219
+
220
+ queue.put(batch_device)
221
+
222
+
223
+ class RedditDataset:
224
+ """
225
+ A class that handles the reddit data files
226
+ """
227
+ def __init__(self, filepath):
228
+ self.filepath = filepath
229
+
230
+ def __iter__(self):
231
+ while True:
232
+ with gzip.open(self.filepath, "rt") as fIn:
233
+ for line in fIn:
234
+ data = json.loads(line)
235
+
236
+ if "response" in data and "context" in data:
237
+ yield [data["response"], data["context"]]
238
+
239
+ class Dataset:
240
+ """
241
+ A class that handles one dataset
242
+ """
243
+ def __init__(self, filepath):
244
+ self.filepath = filepath
245
+
246
+ def __iter__(self):
247
+ max_dataset_size = 10*1000*1000 #Cache small datasets in memory
248
+ dataset = []
249
+ data_format = None
250
+
251
+ while dataset is None or len(dataset) == 0:
252
+ with gzip.open(self.filepath, "rt") as fIn:
253
+ for line in fIn:
254
+ data = json.loads(line)
255
+ if isinstance(data, dict):
256
+ data = data['texts']
257
+
258
+ if data_format is None:
259
+ data_format = len(data)
260
+
261
+ #Ensure that all entries are of the same 2/3 col format
262
+ assert len(data) == data_format
263
+
264
+ if dataset is not None:
265
+ dataset.append(data)
266
+ if len(dataset) >= max_dataset_size:
267
+ dataset = None
268
+
269
+ yield data
270
+
271
+ # Data loaded. Now stream to the queue
272
+ # Shuffle for each epoch
273
+ while True:
274
+ random.shuffle(dataset)
275
+ for data in dataset:
276
+ yield data
277
+
278
+
279
+
280
+ if __name__ == "__main__":
281
+ parser = argparse.ArgumentParser()
282
+ parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
283
+ parser.add_argument('--steps', type=int, default=2000)
284
+ parser.add_argument('--save_steps', type=int, default=10000)
285
+ parser.add_argument('--batch_size', type=int, default=64)
286
+ parser.add_argument('--max_length', type=int, default=128)
287
+ parser.add_argument('--nprocs', type=int, default=8)
288
+ parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
289
+ parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
290
+ parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
291
+ parser.add_argument('data_config', help="A data_config.json file")
292
+ parser.add_argument('output')
293
+ args = parser.parse_args()
294
+
295
+ # Ensure global batch size is divisble by data_sample_size
296
+ assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
297
+
298
+ logging.info("Output: "+args.output)
299
+ if os.path.exists(args.output):
300
+ print("Output folder already exists.")
301
+ input("Continue?")
302
+
303
+ # Write train script to output path
304
+ os.makedirs(args.output, exist_ok=True)
305
+
306
+ data_config_path = os.path.join(args.output, 'data_config.json')
307
+ copyfile(args.data_config, data_config_path)
308
+
309
+ train_script_path = os.path.join(args.output, 'train_script.py')
310
+ copyfile(__file__, train_script_path)
311
+ with open(train_script_path, 'a') as fOut:
312
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
313
+
314
+
315
+
316
+ #Load data config
317
+ with open(args.data_config) as fIn:
318
+ data_config = json.load(fIn)
319
+
320
+ queue = mp.Queue(maxsize=100*args.nprocs)
321
+
322
+ filepaths = []
323
+ dataset_indices = []
324
+ for idx, data in enumerate(data_config):
325
+ filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
326
+ dataset_indices.extend([idx]*data['weight'])
327
+
328
+ # Start producer
329
+ p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
330
+ p.start()
331
+
332
+ # Run training
333
+ print("Start processes:", args.nprocs)
334
+ xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
335
+ print("Training done")
336
+ print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
337
+ print("With 'pkill python' you can kill all remaining python processes")
338
+ p.kill()
339
+ exit()
340
+
341
+
342
+
343
+ # Script was called via:
344
+ #python train_many_data_files_v2.py --steps 1000000 --batch_size 64 --model output/all_datasets_v3_mpnet-base/120000 train_data_configs/all_datasets_v3.json output/all_datasets_v3_mpnet-base_cnt_120k
vocab.txt ADDED
The diff for this file is too large to render. See raw diff