nreimers commited on
Commit
d83c6d8
1 Parent(s): c1d5bdc
1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": true,
4
+ "pooling_mode_mean_tokens": false,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: sentence-similarity
3
+ tags:
4
+ - sentence-transformers
5
+ - feature-extraction
6
+ - sentence-similarity
7
+ ---
8
+
9
+ # multi-qa_v1-distilbert-cls_dot
10
+
11
+ This is a distilbert-base-uncased model trained on all the Q&A datasets of the 1B+ train corpus. It was trained with the v1 setup. See data_config.json and train_script.py in this respository how the model was trained and which datasets have been used.
12
+
13
+ ## Usage
14
+ It can be used for semantic search. Output vectors are **not normalized**. You can find relevant passages by using **dot-product**.
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "distilbert-base-uncased",
3
+ "activation": "gelu",
4
+ "architectures": [
5
+ "DistilBertForMaskedLM"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "dim": 768,
9
+ "dropout": 0.1,
10
+ "hidden_dim": 3072,
11
+ "initializer_range": 0.02,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "distilbert",
14
+ "n_heads": 12,
15
+ "n_layers": 6,
16
+ "pad_token_id": 0,
17
+ "qa_dropout": 0.1,
18
+ "seq_classif_dropout": 0.2,
19
+ "sinusoidal_pos_embds": false,
20
+ "tie_weights_": true,
21
+ "transformers_version": "4.8.2",
22
+ "vocab_size": 30522
23
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.0.0",
4
+ "transformers": "4.6.1",
5
+ "pytorch": "1.8.1"
6
+ }
7
+ }
data_config.json ADDED
@@ -0,0 +1,942 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "stackexchange_title_body/skeptics.stackexchange.com.jsonl.gz",
4
+ "lines": 10009,
5
+ "weight": 3
6
+ },
7
+ {
8
+ "name": "stackexchange_Title_Answer/islam.stackexchange.com.jsonl.gz",
9
+ "lines": 10052,
10
+ "weight": 3
11
+ },
12
+ {
13
+ "name": "stackexchange_Title_Answer/anime.stackexchange.com.jsonl.gz",
14
+ "lines": 10131,
15
+ "weight": 3
16
+ },
17
+ {
18
+ "name": "stackexchange_title_body/writers.stackexchange.com.jsonl.gz",
19
+ "lines": 10157,
20
+ "weight": 3
21
+ },
22
+ {
23
+ "name": "stackexchange_title_body/astronomy.stackexchange.com.jsonl.gz",
24
+ "lines": 10462,
25
+ "weight": 3
26
+ },
27
+ {
28
+ "name": "stackexchange_title_body/vi.stackexchange.com.jsonl.gz",
29
+ "lines": 10551,
30
+ "weight": 3
31
+ },
32
+ {
33
+ "name": "stackexchange_Title_Answer/french.stackexchange.com.jsonl.gz",
34
+ "lines": 10578,
35
+ "weight": 3
36
+ },
37
+ {
38
+ "name": "stackexchange_title_body/cstheory.stackexchange.com.jsonl.gz",
39
+ "lines": 10642,
40
+ "weight": 3
41
+ },
42
+ {
43
+ "name": "stackexchange_Title_Answer/civicrm.stackexchange.com.jsonl.gz",
44
+ "lines": 10648,
45
+ "weight": 3
46
+ },
47
+ {
48
+ "name": "stackexchange_Title_Answer/expressionengine.stackexchange.com.jsonl.gz",
49
+ "lines": 10742,
50
+ "weight": 3
51
+ },
52
+ {
53
+ "name": "stackexchange_title_body/engineering.stackexchange.com.jsonl.gz",
54
+ "lines": 10753,
55
+ "weight": 3
56
+ },
57
+ {
58
+ "name": "stackexchange_Title_Answer/history.stackexchange.com.jsonl.gz",
59
+ "lines": 10766,
60
+ "weight": 3
61
+ },
62
+ {
63
+ "name": "stackexchange_title_body/french.stackexchange.com.jsonl.gz",
64
+ "lines": 10794,
65
+ "weight": 3
66
+ },
67
+ {
68
+ "name": "stackexchange_Title_Answer/politics.stackexchange.com.jsonl.gz",
69
+ "lines": 11047,
70
+ "weight": 3
71
+ },
72
+ {
73
+ "name": "stackexchange_title_body/economics.stackexchange.com.jsonl.gz",
74
+ "lines": 11115,
75
+ "weight": 3
76
+ },
77
+ {
78
+ "name": "stackexchange_Title_Answer/craftcms.stackexchange.com.jsonl.gz",
79
+ "lines": 11236,
80
+ "weight": 3
81
+ },
82
+ {
83
+ "name": "stackexchange_title_body/anime.stackexchange.com.jsonl.gz",
84
+ "lines": 11444,
85
+ "weight": 3
86
+ },
87
+ {
88
+ "name": "stackexchange_Title_Answer/christianity.stackexchange.com.jsonl.gz",
89
+ "lines": 11498,
90
+ "weight": 3
91
+ },
92
+ {
93
+ "name": "stackexchange_Title_Answer/softwarerecs.stackexchange.com.jsonl.gz",
94
+ "lines": 11761,
95
+ "weight": 3
96
+ },
97
+ {
98
+ "name": "stackexchange_Title_Answer/boardgames.stackexchange.com.jsonl.gz",
99
+ "lines": 11805,
100
+ "weight": 3
101
+ },
102
+ {
103
+ "name": "stackexchange_title_body/islam.stackexchange.com.jsonl.gz",
104
+ "lines": 11853,
105
+ "weight": 3
106
+ },
107
+ {
108
+ "name": "stackexchange_title_body/expressionengine.stackexchange.com.jsonl.gz",
109
+ "lines": 11866,
110
+ "weight": 3
111
+ },
112
+ {
113
+ "name": "stackexchange_title_body/politics.stackexchange.com.jsonl.gz",
114
+ "lines": 11894,
115
+ "weight": 3
116
+ },
117
+ {
118
+ "name": "stackexchange_title_body/history.stackexchange.com.jsonl.gz",
119
+ "lines": 12021,
120
+ "weight": 3
121
+ },
122
+ {
123
+ "name": "stackexchange_title_body/christianity.stackexchange.com.jsonl.gz",
124
+ "lines": 12108,
125
+ "weight": 3
126
+ },
127
+ {
128
+ "name": "stackexchange_title_body/boardgames.stackexchange.com.jsonl.gz",
129
+ "lines": 12149,
130
+ "weight": 3
131
+ },
132
+ {
133
+ "name": "stackexchange_title_body/civicrm.stackexchange.com.jsonl.gz",
134
+ "lines": 12543,
135
+ "weight": 3
136
+ },
137
+ {
138
+ "name": "stackexchange_title_body/craftcms.stackexchange.com.jsonl.gz",
139
+ "lines": 12574,
140
+ "weight": 3
141
+ },
142
+ {
143
+ "name": "stackexchange_Title_Answer/networkengineering.stackexchange.com.jsonl.gz",
144
+ "lines": 12590,
145
+ "weight": 3
146
+ },
147
+ {
148
+ "name": "stackexchange_Title_Answer/space.stackexchange.com.jsonl.gz",
149
+ "lines": 12893,
150
+ "weight": 3
151
+ },
152
+ {
153
+ "name": "stackexchange_Title_Answer/quant.stackexchange.com.jsonl.gz",
154
+ "lines": 12933,
155
+ "weight": 3
156
+ },
157
+ {
158
+ "name": "stackexchange_Title_Answer/philosophy.stackexchange.com.jsonl.gz",
159
+ "lines": 13114,
160
+ "weight": 3
161
+ },
162
+ {
163
+ "name": "stackexchange_Title_Answer/gardening.stackexchange.com.jsonl.gz",
164
+ "lines": 13246,
165
+ "weight": 3
166
+ },
167
+ {
168
+ "name": "stackexchange_title_body/hinduism.stackexchange.com.jsonl.gz",
169
+ "lines": 13450,
170
+ "weight": 4
171
+ },
172
+ {
173
+ "name": "stackexchange_title_body/networkengineering.stackexchange.com.jsonl.gz",
174
+ "lines": 13454,
175
+ "weight": 4
176
+ },
177
+ {
178
+ "name": "stackexchange_Title_Answer/german.stackexchange.com.jsonl.gz",
179
+ "lines": 13733,
180
+ "weight": 4
181
+ },
182
+ {
183
+ "name": "stackexchange_title_body/german.stackexchange.com.jsonl.gz",
184
+ "lines": 13950,
185
+ "weight": 4
186
+ },
187
+ {
188
+ "name": "stackexchange_title_body/philosophy.stackexchange.com.jsonl.gz",
189
+ "lines": 14829,
190
+ "weight": 4
191
+ },
192
+ {
193
+ "name": "stackexchange_title_body/gardening.stackexchange.com.jsonl.gz",
194
+ "lines": 15136,
195
+ "weight": 4
196
+ },
197
+ {
198
+ "name": "stackexchange_title_body/space.stackexchange.com.jsonl.gz",
199
+ "lines": 15142,
200
+ "weight": 4
201
+ },
202
+ {
203
+ "name": "stackexchange_Title_Answer/bicycles.stackexchange.com.jsonl.gz",
204
+ "lines": 15708,
205
+ "weight": 4
206
+ },
207
+ {
208
+ "name": "stackexchange_Title_Answer/law.stackexchange.com.jsonl.gz",
209
+ "lines": 16133,
210
+ "weight": 4
211
+ },
212
+ {
213
+ "name": "stackexchange_Title_Answer/arduino.stackexchange.com.jsonl.gz",
214
+ "lines": 16281,
215
+ "weight": 4
216
+ },
217
+ {
218
+ "name": "stackexchange_title_body/bicycles.stackexchange.com.jsonl.gz",
219
+ "lines": 16353,
220
+ "weight": 4
221
+ },
222
+ {
223
+ "name": "stackexchange_Title_Answer/emacs.stackexchange.com.jsonl.gz",
224
+ "lines": 16830,
225
+ "weight": 4
226
+ },
227
+ {
228
+ "name": "stackexchange_title_body/quant.stackexchange.com.jsonl.gz",
229
+ "lines": 17261,
230
+ "weight": 4
231
+ },
232
+ {
233
+ "name": "stackexchange_Title_Answer/dsp.stackexchange.com.jsonl.gz",
234
+ "lines": 17430,
235
+ "weight": 4
236
+ },
237
+ {
238
+ "name": "stackexchange_Title_Answer/puzzling.stackexchange.com.jsonl.gz",
239
+ "lines": 17448,
240
+ "weight": 4
241
+ },
242
+ {
243
+ "name": "stackexchange_title_body/puzzling.stackexchange.com.jsonl.gz",
244
+ "lines": 17851,
245
+ "weight": 5
246
+ },
247
+ {
248
+ "name": "stackexchange_title_body/law.stackexchange.com.jsonl.gz",
249
+ "lines": 17941,
250
+ "weight": 5
251
+ },
252
+ {
253
+ "name": "stackexchange_Title_Answer/movies.stackexchange.com.jsonl.gz",
254
+ "lines": 18243,
255
+ "weight": 5
256
+ },
257
+ {
258
+ "name": "stackexchange_Title_Answer/mechanics.stackexchange.com.jsonl.gz",
259
+ "lines": 18613,
260
+ "weight": 5
261
+ },
262
+ {
263
+ "name": "stackexchange_Title_Answer/aviation.stackexchange.com.jsonl.gz",
264
+ "lines": 18755,
265
+ "weight": 5
266
+ },
267
+ {
268
+ "name": "stackexchange_Title_Answer/biology.stackexchange.com.jsonl.gz",
269
+ "lines": 19277,
270
+ "weight": 5
271
+ },
272
+ {
273
+ "name": "stackexchange_Title_Answer/crypto.stackexchange.com.jsonl.gz",
274
+ "lines": 19404,
275
+ "weight": 5
276
+ },
277
+ {
278
+ "name": "stackexchange_title_body/arduino.stackexchange.com.jsonl.gz",
279
+ "lines": 19553,
280
+ "weight": 5
281
+ },
282
+ {
283
+ "name": "stackexchange_Title_Answer/music.stackexchange.com.jsonl.gz",
284
+ "lines": 19936,
285
+ "weight": 5
286
+ },
287
+ {
288
+ "name": "stackexchange_title_body/aviation.stackexchange.com.jsonl.gz",
289
+ "lines": 20139,
290
+ "weight": 5
291
+ },
292
+ {
293
+ "name": "stackexchange_title_body/softwarerecs.stackexchange.com.jsonl.gz",
294
+ "lines": 20142,
295
+ "weight": 5
296
+ },
297
+ {
298
+ "name": "stackexchange_title_body/movies.stackexchange.com.jsonl.gz",
299
+ "lines": 20181,
300
+ "weight": 5
301
+ },
302
+ {
303
+ "name": "stackexchange_Title_Answer/datascience.stackexchange.com.jsonl.gz",
304
+ "lines": 20503,
305
+ "weight": 5
306
+ },
307
+ {
308
+ "name": "stackexchange_title_body/music.stackexchange.com.jsonl.gz",
309
+ "lines": 20636,
310
+ "weight": 5
311
+ },
312
+ {
313
+ "name": "stackexchange_Title_Answer/japanese.stackexchange.com.jsonl.gz",
314
+ "lines": 20948,
315
+ "weight": 5
316
+ },
317
+ {
318
+ "name": "stackexchange_title_body/emacs.stackexchange.com.jsonl.gz",
319
+ "lines": 21055,
320
+ "weight": 5
321
+ },
322
+ {
323
+ "name": "stackexchange_title_body/dsp.stackexchange.com.jsonl.gz",
324
+ "lines": 21252,
325
+ "weight": 5
326
+ },
327
+ {
328
+ "name": "stackexchange_title_body/japanese.stackexchange.com.jsonl.gz",
329
+ "lines": 22056,
330
+ "weight": 5
331
+ },
332
+ {
333
+ "name": "stackexchange_Title_Answer/bitcoin.stackexchange.com.jsonl.gz",
334
+ "lines": 22474,
335
+ "weight": 6
336
+ },
337
+ {
338
+ "name": "stackexchange_Title_Answer/cooking.stackexchange.com.jsonl.gz",
339
+ "lines": 22641,
340
+ "weight": 6
341
+ },
342
+ {
343
+ "name": "stackexchange_title_body/mechanics.stackexchange.com.jsonl.gz",
344
+ "lines": 22868,
345
+ "weight": 6
346
+ },
347
+ {
348
+ "name": "stackexchange_Title_Answer/photo.stackexchange.com.jsonl.gz",
349
+ "lines": 23204,
350
+ "weight": 6
351
+ },
352
+ {
353
+ "name": "stackexchange_title_body/crypto.stackexchange.com.jsonl.gz",
354
+ "lines": 23231,
355
+ "weight": 6
356
+ },
357
+ {
358
+ "name": "stackexchange_title_body/cooking.stackexchange.com.jsonl.gz",
359
+ "lines": 23705,
360
+ "weight": 6
361
+ },
362
+ {
363
+ "name": "stackexchange_title_body/photo.stackexchange.com.jsonl.gz",
364
+ "lines": 23753,
365
+ "weight": 6
366
+ },
367
+ {
368
+ "name": "stackexchange_Title_Answer/workplace.stackexchange.com.jsonl.gz",
369
+ "lines": 24012,
370
+ "weight": 6
371
+ },
372
+ {
373
+ "name": "stackexchange_Title_Answer/meta.stackoverflow.com.jsonl.gz",
374
+ "lines": 24044,
375
+ "weight": 6
376
+ },
377
+ {
378
+ "name": "stackexchange_Title_Answer/raspberrypi.stackexchange.com.jsonl.gz",
379
+ "lines": 24143,
380
+ "weight": 6
381
+ },
382
+ {
383
+ "name": "stackexchange_title_body/workplace.stackexchange.com.jsonl.gz",
384
+ "lines": 24189,
385
+ "weight": 6
386
+ },
387
+ {
388
+ "name": "stackexchange_title_body/biology.stackexchange.com.jsonl.gz",
389
+ "lines": 24447,
390
+ "weight": 6
391
+ },
392
+ {
393
+ "name": "stackexchange_Title_Answer/webapps.stackexchange.com.jsonl.gz",
394
+ "lines": 24867,
395
+ "weight": 6
396
+ },
397
+ {
398
+ "name": "stackexchange_title_body/bitcoin.stackexchange.com.jsonl.gz",
399
+ "lines": 25374,
400
+ "weight": 6
401
+ },
402
+ {
403
+ "name": "stackexchange_Title_Answer/judaism.stackexchange.com.jsonl.gz",
404
+ "lines": 26085,
405
+ "weight": 6
406
+ },
407
+ {
408
+ "name": "stackexchange_Title_Answer/ethereum.stackexchange.com.jsonl.gz",
409
+ "lines": 26124,
410
+ "weight": 6
411
+ },
412
+ {
413
+ "name": "stackexchange_Title_Answer/worldbuilding.stackexchange.com.jsonl.gz",
414
+ "lines": 26210,
415
+ "weight": 6
416
+ },
417
+ {
418
+ "name": "stackexchange_title_body/worldbuilding.stackexchange.com.jsonl.gz",
419
+ "lines": 26763,
420
+ "weight": 7
421
+ },
422
+ {
423
+ "name": "stackexchange_Title_Answer/chemistry.stackexchange.com.jsonl.gz",
424
+ "lines": 27061,
425
+ "weight": 7
426
+ },
427
+ {
428
+ "name": "stackexchange_title_body/datascience.stackexchange.com.jsonl.gz",
429
+ "lines": 27397,
430
+ "weight": 7
431
+ },
432
+ {
433
+ "name": "stackexchange_Title_Answer/graphicdesign.stackexchange.com.jsonl.gz",
434
+ "lines": 28083,
435
+ "weight": 7
436
+ },
437
+ {
438
+ "name": "stackexchange_Title_Answer/ux.stackexchange.com.jsonl.gz",
439
+ "lines": 28901,
440
+ "weight": 7
441
+ },
442
+ {
443
+ "name": "stackexchange_title_body/ux.stackexchange.com.jsonl.gz",
444
+ "lines": 29403,
445
+ "weight": 7
446
+ },
447
+ {
448
+ "name": "stackexchange_Title_Answer/money.stackexchange.com.jsonl.gz",
449
+ "lines": 29404,
450
+ "weight": 7
451
+ },
452
+ {
453
+ "name": "stackexchange_title_body/webapps.stackexchange.com.jsonl.gz",
454
+ "lines": 29697,
455
+ "weight": 7
456
+ },
457
+ {
458
+ "name": "stackexchange_Title_Answer/cs.stackexchange.com.jsonl.gz",
459
+ "lines": 30010,
460
+ "weight": 7
461
+ },
462
+ {
463
+ "name": "stackexchange_title_body/graphicdesign.stackexchange.com.jsonl.gz",
464
+ "lines": 30233,
465
+ "weight": 7
466
+ },
467
+ {
468
+ "name": "stackexchange_Title_Answer/webmasters.stackexchange.com.jsonl.gz",
469
+ "lines": 30370,
470
+ "weight": 7
471
+ },
472
+ {
473
+ "name": "stackexchange_title_body/raspberrypi.stackexchange.com.jsonl.gz",
474
+ "lines": 30625,
475
+ "weight": 7
476
+ },
477
+ {
478
+ "name": "stackexchange_title_body/money.stackexchange.com.jsonl.gz",
479
+ "lines": 32021,
480
+ "weight": 8
481
+ },
482
+ {
483
+ "name": "stackexchange_title_body/judaism.stackexchange.com.jsonl.gz",
484
+ "lines": 32028,
485
+ "weight": 8
486
+ },
487
+ {
488
+ "name": "stackexchange_Title_Answer/academia.stackexchange.com.jsonl.gz",
489
+ "lines": 32137,
490
+ "weight": 8
491
+ },
492
+ {
493
+ "name": "stackexchange_title_body/ethereum.stackexchange.com.jsonl.gz",
494
+ "lines": 32760,
495
+ "weight": 8
496
+ },
497
+ {
498
+ "name": "stackexchange_title_body/academia.stackexchange.com.jsonl.gz",
499
+ "lines": 34331,
500
+ "weight": 8
501
+ },
502
+ {
503
+ "name": "stackexchange_title_body/chemistry.stackexchange.com.jsonl.gz",
504
+ "lines": 34506,
505
+ "weight": 8
506
+ },
507
+ {
508
+ "name": "stackexchange_title_body/webmasters.stackexchange.com.jsonl.gz",
509
+ "lines": 34559,
510
+ "weight": 8
511
+ },
512
+ {
513
+ "name": "stackexchange_title_body/meta.stackoverflow.com.jsonl.gz",
514
+ "lines": 36456,
515
+ "weight": 9
516
+ },
517
+ {
518
+ "name": "stackexchange_Title_Answer/travel.stackexchange.com.jsonl.gz",
519
+ "lines": 36533,
520
+ "weight": 9
521
+ },
522
+ {
523
+ "name": "stackexchange_Title_Answer/android.stackexchange.com.jsonl.gz",
524
+ "lines": 38077,
525
+ "weight": 9
526
+ },
527
+ {
528
+ "name": "stackexchange_title_body/cs.stackexchange.com.jsonl.gz",
529
+ "lines": 38314,
530
+ "weight": 9
531
+ },
532
+ {
533
+ "name": "stackexchange_Title_Answer/gamedev.stackexchange.com.jsonl.gz",
534
+ "lines": 40154,
535
+ "weight": 10
536
+ },
537
+ {
538
+ "name": "stackexchange_Title_Answer/rpg.stackexchange.com.jsonl.gz",
539
+ "lines": 40435,
540
+ "weight": 10
541
+ },
542
+ {
543
+ "name": "stackexchange_title_body/travel.stackexchange.com.jsonl.gz",
544
+ "lines": 41227,
545
+ "weight": 10
546
+ },
547
+ {
548
+ "name": "stackexchange_Title_Answer/codereview.stackexchange.com.jsonl.gz",
549
+ "lines": 41748,
550
+ "weight": 10
551
+ },
552
+ {
553
+ "name": "stackexchange_title_body/rpg.stackexchange.com.jsonl.gz",
554
+ "lines": 42303,
555
+ "weight": 10
556
+ },
557
+ {
558
+ "name": "stackexchange_title_body/codereview.stackexchange.com.jsonl.gz",
559
+ "lines": 45765,
560
+ "weight": 11
561
+ },
562
+ {
563
+ "name": "stackexchange_title_body/gamedev.stackexchange.com.jsonl.gz",
564
+ "lines": 46485,
565
+ "weight": 11
566
+ },
567
+ {
568
+ "name": "stackexchange_Title_Answer/softwareengineering.stackexchange.com.jsonl.gz",
569
+ "lines": 51326,
570
+ "weight": 12
571
+ },
572
+ {
573
+ "name": "stackexchange_Title_Answer/security.stackexchange.com.jsonl.gz",
574
+ "lines": 51355,
575
+ "weight": 12
576
+ },
577
+ {
578
+ "name": "stackexchange_title_body/android.stackexchange.com.jsonl.gz",
579
+ "lines": 51608,
580
+ "weight": 12
581
+ },
582
+ {
583
+ "name": "stackexchange_Title_Answer/diy.stackexchange.com.jsonl.gz",
584
+ "lines": 52896,
585
+ "weight": 12
586
+ },
587
+ {
588
+ "name": "stackexchange_title_body/softwareengineering.stackexchange.com.jsonl.gz",
589
+ "lines": 53942,
590
+ "weight": 13
591
+ },
592
+ {
593
+ "name": "stackexchange_Title_Answer/blender.stackexchange.com.jsonl.gz",
594
+ "lines": 54153,
595
+ "weight": 13
596
+ },
597
+ {
598
+ "name": "stackexchange_Title_Answer/scifi.stackexchange.com.jsonl.gz",
599
+ "lines": 54805,
600
+ "weight": 13
601
+ },
602
+ {
603
+ "name": "stackexchange_title_body/security.stackexchange.com.jsonl.gz",
604
+ "lines": 58000,
605
+ "weight": 14
606
+ },
607
+ {
608
+ "name": "stackexchange_Title_Answer/mathematica.stackexchange.com.jsonl.gz",
609
+ "lines": 59895,
610
+ "weight": 14
611
+ },
612
+ {
613
+ "name": "stackexchange_title_body/diy.stackexchange.com.jsonl.gz",
614
+ "lines": 60083,
615
+ "weight": 14
616
+ },
617
+ {
618
+ "name": "stackexchange_Title_Answer/meta.stackexchange.com.jsonl.gz",
619
+ "lines": 60744,
620
+ "weight": 14
621
+ },
622
+ {
623
+ "name": "stackexchange_title_body/scifi.stackexchange.com.jsonl.gz",
624
+ "lines": 61528,
625
+ "weight": 14
626
+ },
627
+ {
628
+ "name": "stackexchange_Title_Answer/drupal.stackexchange.com.jsonl.gz",
629
+ "lines": 67817,
630
+ "weight": 16
631
+ },
632
+ {
633
+ "name": "stackexchange_Title_Answer/dba.stackexchange.com.jsonl.gz",
634
+ "lines": 71449,
635
+ "weight": 17
636
+ },
637
+ {
638
+ "name": "stackexchange_title_body/mathematica.stackexchange.com.jsonl.gz",
639
+ "lines": 73131,
640
+ "weight": 17
641
+ },
642
+ {
643
+ "name": "stackexchange_Title_Answer/ell.stackexchange.com.jsonl.gz",
644
+ "lines": 77892,
645
+ "weight": 18
646
+ },
647
+ {
648
+ "name": "stackexchange_Title_Answer/magento.stackexchange.com.jsonl.gz",
649
+ "lines": 79241,
650
+ "weight": 18
651
+ },
652
+ {
653
+ "name": "stackexchange_title_body/drupal.stackexchange.com.jsonl.gz",
654
+ "lines": 79717,
655
+ "weight": 18
656
+ },
657
+ {
658
+ "name": "stackexchange_Title_Answer/sharepoint.stackexchange.com.jsonl.gz",
659
+ "lines": 80420,
660
+ "weight": 19
661
+ },
662
+ {
663
+ "name": "stackexchange_title_body/blender.stackexchange.com.jsonl.gz",
664
+ "lines": 80766,
665
+ "weight": 19
666
+ },
667
+ {
668
+ "name": "stackexchange_title_body/dba.stackexchange.com.jsonl.gz",
669
+ "lines": 81871,
670
+ "weight": 19
671
+ },
672
+ {
673
+ "name": "stackexchange_Title_Answer/gaming.stackexchange.com.jsonl.gz",
674
+ "lines": 82887,
675
+ "weight": 19
676
+ },
677
+ {
678
+ "name": "stackexchange_title_body/ell.stackexchange.com.jsonl.gz",
679
+ "lines": 83271,
680
+ "weight": 19
681
+ },
682
+ {
683
+ "name": "stackexchange_title_body/meta.stackexchange.com.jsonl.gz",
684
+ "lines": 83510,
685
+ "weight": 19
686
+ },
687
+ {
688
+ "name": "stackexchange_Title_Answer/wordpress.stackexchange.com.jsonl.gz",
689
+ "lines": 83621,
690
+ "weight": 19
691
+ },
692
+ {
693
+ "name": "stackexchange_Title_Answer/mathoverflow.net.jsonl.gz",
694
+ "lines": 85289,
695
+ "weight": 20
696
+ },
697
+ {
698
+ "name": "stackexchange_Title_Answer/salesforce.stackexchange.com.jsonl.gz",
699
+ "lines": 87272,
700
+ "weight": 20
701
+ },
702
+ {
703
+ "name": "stackexchange_title_body/gaming.stackexchange.com.jsonl.gz",
704
+ "lines": 88912,
705
+ "weight": 21
706
+ },
707
+ {
708
+ "name": "stackexchange_Title_Answer/apple.stackexchange.com.jsonl.gz",
709
+ "lines": 92487,
710
+ "weight": 21
711
+ },
712
+ {
713
+ "name": "stackexchange_title_body/sharepoint.stackexchange.com.jsonl.gz",
714
+ "lines": 94011,
715
+ "weight": 22
716
+ },
717
+ {
718
+ "name": "stackexchange_title_body/magento.stackexchange.com.jsonl.gz",
719
+ "lines": 99991,
720
+ "weight": 23
721
+ },
722
+ {
723
+ "name": "stackexchange_Title_Answer/gis.stackexchange.com.jsonl.gz",
724
+ "lines": 100254,
725
+ "weight": 23
726
+ },
727
+ {
728
+ "name": "stackexchange_title_body/wordpress.stackexchange.com.jsonl.gz",
729
+ "lines": 100474,
730
+ "weight": 23
731
+ },
732
+ {
733
+ "name": "stackexchange_Title_Answer/english.stackexchange.com.jsonl.gz",
734
+ "lines": 100640,
735
+ "weight": 23
736
+ },
737
+ {
738
+ "name": "stackexchange_title_body/salesforce.stackexchange.com.jsonl.gz",
739
+ "lines": 105260,
740
+ "weight": 24
741
+ },
742
+ {
743
+ "name": "stackexchange_title_body/english.stackexchange.com.jsonl.gz",
744
+ "lines": 109522,
745
+ "weight": 25
746
+ },
747
+ {
748
+ "name": "stackexchange_title_body/apple.stackexchange.com.jsonl.gz",
749
+ "lines": 110622,
750
+ "weight": 25
751
+ },
752
+ {
753
+ "name": "stackexchange_Title_Answer/stats.stackexchange.com.jsonl.gz",
754
+ "lines": 115679,
755
+ "weight": 27
756
+ },
757
+ {
758
+ "name": "stackexchange_title_body/mathoverflow.net.jsonl.gz",
759
+ "lines": 120851,
760
+ "weight": 28
761
+ },
762
+ {
763
+ "name": "stackexchange_Title_Answer/electronics.stackexchange.com.jsonl.gz",
764
+ "lines": 129494,
765
+ "weight": 30
766
+ },
767
+ {
768
+ "name": "stackexchange_title_body/gis.stackexchange.com.jsonl.gz",
769
+ "lines": 131000,
770
+ "weight": 30
771
+ },
772
+ {
773
+ "name": "stackexchange_Title_Answer/physics.stackexchange.com.jsonl.gz",
774
+ "lines": 141230,
775
+ "weight": 32
776
+ },
777
+ {
778
+ "name": "stackexchange_title_body/electronics.stackexchange.com.jsonl.gz",
779
+ "lines": 143582,
780
+ "weight": 33
781
+ },
782
+ {
783
+ "name": "TriviaQA_pairs.jsonl.gz",
784
+ "lines": 73346,
785
+ "weight": 34
786
+ },
787
+ {
788
+ "name": "stackexchange_Title_Answer/unix.stackexchange.com.jsonl.gz",
789
+ "lines": 155414,
790
+ "weight": 36
791
+ },
792
+ {
793
+ "name": "stackexchange_Title_Answer/tex.stackexchange.com.jsonl.gz",
794
+ "lines": 171628,
795
+ "weight": 39
796
+ },
797
+ {
798
+ "name": "squad_pairs.jsonl.gz",
799
+ "lines": 87599,
800
+ "weight": 40
801
+ },
802
+ {
803
+ "name": "stackexchange_title_body/physics.stackexchange.com.jsonl.gz",
804
+ "lines": 173307,
805
+ "weight": 40
806
+ },
807
+ {
808
+ "name": "stackexchange_title_body/stats.stackexchange.com.jsonl.gz",
809
+ "lines": 173466,
810
+ "weight": 40
811
+ },
812
+ {
813
+ "name": "stackexchange_title_body/unix.stackexchange.com.jsonl.gz",
814
+ "lines": 185997,
815
+ "weight": 42
816
+ },
817
+ {
818
+ "name": "NQ-train_pairs.jsonl.gz",
819
+ "lines": 100231,
820
+ "weight": 46
821
+ },
822
+ {
823
+ "name": "stackexchange_title_body/tex.stackexchange.com.jsonl.gz",
824
+ "lines": 202954,
825
+ "weight": 46
826
+ },
827
+ {
828
+ "name": "quora_duplicates_triplets.jsonl.gz",
829
+ "lines": 103663,
830
+ "weight": 47
831
+ },
832
+ {
833
+ "name": "stackexchange_Title_Answer/serverfault.com.jsonl.gz",
834
+ "lines": 238507,
835
+ "weight": 54
836
+ },
837
+ {
838
+ "name": "stackexchange_Title_Answer/askubuntu.com.jsonl.gz",
839
+ "lines": 267135,
840
+ "weight": 61
841
+ },
842
+ {
843
+ "name": "stackexchange_title_body/serverfault.com.jsonl.gz",
844
+ "lines": 270904,
845
+ "weight": 62
846
+ },
847
+ {
848
+ "name": "stackexchange_duplicate_questions_title_title.jsonl.gz",
849
+ "lines": 304525,
850
+ "weight": 69
851
+ },
852
+ {
853
+ "name": "stackexchange_title_body/askubuntu.com.jsonl.gz",
854
+ "lines": 347925,
855
+ "weight": 79
856
+ },
857
+ {
858
+ "name": "stackexchange_Title_Answer/superuser.com.jsonl.gz",
859
+ "lines": 352610,
860
+ "weight": 80
861
+ },
862
+ {
863
+ "name": "stackexchange_title_body/superuser.com.jsonl.gz",
864
+ "lines": 435463,
865
+ "weight": 99
866
+ },
867
+ {
868
+ "name": "stackexchange_title_body/small_stackexchanges.jsonl.gz",
869
+ "lines": 448146,
870
+ "weight": 102
871
+ },
872
+ {
873
+ "name": "stackexchange_Title_Answer/small_stackexchanges.jsonl.gz",
874
+ "lines": 460256,
875
+ "weight": 104
876
+ },
877
+ {
878
+ "name": "eli5_question_answer.jsonl.gz",
879
+ "lines": 325475,
880
+ "weight": 147
881
+ },
882
+ {
883
+ "name": "yahoo_answers_title_question.jsonl.gz",
884
+ "lines": 659896,
885
+ "weight": 149
886
+ },
887
+ {
888
+ "name": "PAQ_pairs.jsonl.gz",
889
+ "lines": 64371441,
890
+ "weight": 150
891
+ },
892
+ {
893
+ "name": "WikiAnswers_pairs.jsonl.gz",
894
+ "lines": 77427422,
895
+ "weight": 150
896
+ },
897
+ {
898
+ "name": "stackexchange_Title_Answer/math.stackexchange.com.jsonl.gz",
899
+ "lines": 1100953,
900
+ "weight": 226
901
+ },
902
+ {
903
+ "name": "yahoo_answers_title_answer.jsonl.gz",
904
+ "lines": 1198260,
905
+ "weight": 226
906
+ },
907
+ {
908
+ "name": "stackexchange_title_body/math.stackexchange.com.jsonl.gz",
909
+ "lines": 1338443,
910
+ "weight": 226
911
+ },
912
+ {
913
+ "name": "stackexchange_Title_Answer/stackoverflow.com-Posts.jsonl.gz",
914
+ "lines": 15768211,
915
+ "weight": 226
916
+ },
917
+ {
918
+ "name": "stackexchange_title_body/stackoverflow.com-Posts.jsonl.gz",
919
+ "lines": 18562443,
920
+ "weight": 226
921
+ },
922
+ {
923
+ "name": "searchQA_question_top5_snippets_merged.jsonl.gz",
924
+ "lines": 582261,
925
+ "weight": 263
926
+ },
927
+ {
928
+ "name": "amazon-qa-train-pairs.jsonl.gz",
929
+ "lines": 2448839,
930
+ "weight": 451
931
+ },
932
+ {
933
+ "name": "gooaq_pairs.jsonl.gz",
934
+ "lines": 3012496,
935
+ "weight": 451
936
+ },
937
+ {
938
+ "name": "msmarco-query_passage_negative_v2.jsonl.gz",
939
+ "lines": 17579773,
940
+ "weight": 1000
941
+ }
942
+ ]
modules.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ }
14
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7aeda1645569db9e276ed29e1dc0d1df086f30f5e2b581bf57fec097843d0196
3
+ size 265482105
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ {
2
+ "max_seq_length": 512,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "distilbert-base-uncased", "tokenizer_class": "DistilBertTokenizer"}
train_script.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train script for a single file
3
+
4
+ Need to set the TPU address first:
5
+ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6
+ """
7
+
8
+ import torch.multiprocessing as mp
9
+ import threading
10
+ import time
11
+ import random
12
+ import sys
13
+ import argparse
14
+ import gzip
15
+ import json
16
+ import logging
17
+ import tqdm
18
+ import torch
19
+ from torch import nn
20
+ from torch.utils.data import DataLoader
21
+ import torch
22
+ import torch_xla
23
+ import torch_xla.core
24
+ import torch_xla.core.functions
25
+ import torch_xla.core.xla_model as xm
26
+ import torch_xla.distributed.xla_multiprocessing as xmp
27
+ import torch_xla.distributed.parallel_loader as pl
28
+ import os
29
+ from shutil import copyfile
30
+
31
+
32
+ from transformers import (
33
+ AdamW,
34
+ AutoModel,
35
+ AutoTokenizer,
36
+ get_linear_schedule_with_warmup,
37
+ set_seed,
38
+ )
39
+
40
+ class AutoModelForSentenceEmbedding(nn.Module):
41
+ def __init__(self, model_name, tokenizer, args):
42
+ super(AutoModelForSentenceEmbedding, self).__init__()
43
+
44
+ assert args.pooling in ['mean', 'cls']
45
+
46
+ self.model = AutoModel.from_pretrained(model_name)
47
+ self.normalize = not args.no_normalize
48
+ self.tokenizer = tokenizer
49
+ self.pooling = args.pooling
50
+
51
+ def forward(self, **kwargs):
52
+ model_output = self.model(**kwargs)
53
+ if self.pooling == 'mean':
54
+ embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
55
+ elif self.pooling == 'cls':
56
+ embeddings = self.cls_pooling(model_output, kwargs['attention_mask'])
57
+
58
+ if self.normalize:
59
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
60
+
61
+ return embeddings
62
+
63
+ def mean_pooling(self, model_output, attention_mask):
64
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
65
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
66
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
67
+
68
+ def cls_pooling(self, model_output, attention_mask):
69
+ return model_output[0][:,0]
70
+
71
+ def save_pretrained(self, output_path):
72
+ if xm.is_master_ordinal():
73
+ self.tokenizer.save_pretrained(output_path)
74
+ self.model.config.save_pretrained(output_path)
75
+
76
+ xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
77
+
78
+
79
+
80
+
81
+ def train_function(index, args, queue):
82
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
83
+ model = AutoModelForSentenceEmbedding(args.model, tokenizer, args)
84
+
85
+
86
+ ### Train Loop
87
+ device = xm.xla_device()
88
+ model = model.to(device)
89
+
90
+ # Instantiate optimizer
91
+ optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
92
+
93
+ lr_scheduler = get_linear_schedule_with_warmup(
94
+ optimizer=optimizer,
95
+ num_warmup_steps=500,
96
+ num_training_steps=args.steps,
97
+ )
98
+
99
+ # Now we train the model
100
+ cross_entropy_loss = nn.CrossEntropyLoss()
101
+ max_grad_norm = 1
102
+
103
+ model.train()
104
+
105
+ for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
106
+ #### Get the batch data
107
+ batch = queue.get()
108
+ #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))
109
+
110
+
111
+ if len(batch[0]) == 2: #(anchor, positive)
112
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length_a, truncation=True, padding="max_length")
113
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length_b, truncation=True, padding="max_length")
114
+
115
+ ### Compute embeddings
116
+ embeddings_a = model(**text1.to(device))
117
+ embeddings_b = model(**text2.to(device))
118
+
119
+ ### Gather all embedings
120
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
121
+ embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
122
+
123
+ ### Compute similarity scores 512 x 512
124
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
125
+
126
+ ### Compute cross-entropy loss
127
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
128
+
129
+ ## Symmetric loss as in CLIP
130
+ loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
131
+
132
+ else: #(anchor, positive, negative)
133
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length_a, truncation=True, padding="max_length")
134
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length_b, truncation=True, padding="max_length")
135
+ text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length_b, truncation=True, padding="max_length")
136
+
137
+ embeddings_a = model(**text1.to(device))
138
+ embeddings_b1 = model(**text2.to(device))
139
+ embeddings_b2 = model(**text3.to(device))
140
+
141
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
142
+ embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
143
+ embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)
144
+
145
+ embeddings_b = torch.cat([embeddings_b1, embeddings_b2])
146
+
147
+ ### Compute similarity scores 512 x 1024
148
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
149
+
150
+ ### Compute cross-entropy loss
151
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
152
+
153
+ ## One-way loss
154
+ loss = cross_entropy_loss(scores, labels)
155
+
156
+
157
+ # Backward pass
158
+ optimizer.zero_grad()
159
+ loss.backward()
160
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
161
+
162
+ xm.optimizer_step(optimizer, barrier=True)
163
+ lr_scheduler.step()
164
+
165
+
166
+ #Save model
167
+ if (global_step+1) % args.save_steps == 0:
168
+ output_path = os.path.join(args.output, str(global_step+1))
169
+ xm.master_print("save model: "+output_path)
170
+ model.save_pretrained(output_path)
171
+
172
+
173
+ output_path = os.path.join(args.output, "final")
174
+ xm.master_print("save model final: "+ output_path)
175
+ model.save_pretrained(output_path)
176
+
177
+
178
+ def produce_data(args, queue, filepaths, dataset_indices):
179
+ global_batch_size = args.batch_size*args.nprocs #Global batch size
180
+ size_per_dataset = int(global_batch_size / args.datasets_per_batch) #How many datasets per batch
181
+ num_same_dataset = int(size_per_dataset / args.batch_size)
182
+ print("producer", "global_batch_size", global_batch_size)
183
+ print("producer", "size_per_dataset", size_per_dataset)
184
+ print("producer", "num_same_dataset", num_same_dataset)
185
+
186
+ datasets = []
187
+ for filepath in filepaths:
188
+ if "reddit_" in filepath: #Special dataset class for Reddit files
189
+ data_obj = RedditDataset(filepath)
190
+ else:
191
+ data_obj = Dataset(filepath)
192
+ datasets.append(iter(data_obj))
193
+
194
+ # Store if dataset is in a 2 col or 3 col format
195
+ num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}
196
+
197
+ while True:
198
+ texts_in_batch = set()
199
+ batch_format = None #2 vs 3 col format for this batch
200
+
201
+ #Add data from several sub datasets
202
+ for _ in range(args.datasets_per_batch):
203
+ valid_dataset = False #Check that datasets have the same 2/3 col format
204
+ while not valid_dataset:
205
+ data_idx = random.choice(dataset_indices)
206
+ if batch_format is None:
207
+ batch_format = num_cols[data_idx]
208
+ valid_dataset = True
209
+ else: #Check that this dataset has the same format
210
+ valid_dataset = (batch_format == num_cols[data_idx])
211
+
212
+ #Get data from this dataset
213
+ dataset = datasets[data_idx]
214
+ for _ in range(num_same_dataset):
215
+ for _ in range(args.nprocs):
216
+ batch_device = [] #A batch for one device
217
+ while len(batch_device) < args.batch_size:
218
+ sample = next(dataset)
219
+ in_batch = False
220
+ for text in sample:
221
+ if text in texts_in_batch:
222
+ in_batch = True
223
+ break
224
+
225
+ if not in_batch:
226
+ for text in sample:
227
+ texts_in_batch.add(text)
228
+ batch_device.append(sample)
229
+
230
+ queue.put(batch_device)
231
+
232
+
233
+ class RedditDataset:
234
+ """
235
+ A class that handles the reddit data files
236
+ """
237
+ def __init__(self, filepath):
238
+ self.filepath = filepath
239
+
240
+ def __iter__(self):
241
+ while True:
242
+ with gzip.open(self.filepath, "rt") as fIn:
243
+ for line in fIn:
244
+ data = json.loads(line)
245
+
246
+ if "response" in data and "context" in data:
247
+ yield [data["response"], data["context"]]
248
+
249
+ class Dataset:
250
+ """
251
+ A class that handles one dataset
252
+ """
253
+ def __init__(self, filepath):
254
+ self.filepath = filepath
255
+
256
+ def __iter__(self):
257
+ max_dataset_size = 20*1000*1000 #Cache small datasets in memory
258
+ dataset = []
259
+ data_format = None
260
+
261
+ while dataset is None or len(dataset) == 0:
262
+ with gzip.open(self.filepath, "rt") as fIn:
263
+ for line in fIn:
264
+ data = json.loads(line)
265
+ if isinstance(data, dict):
266
+ data = data['texts']
267
+
268
+ if data_format is None:
269
+ data_format = len(data)
270
+
271
+ #Ensure that all entries are of the same 2/3 col format
272
+ assert len(data) == data_format
273
+
274
+ if dataset is not None:
275
+ dataset.append(data)
276
+ if len(dataset) >= max_dataset_size:
277
+ dataset = None
278
+
279
+ yield data
280
+
281
+ # Data loaded. Now stream to the queue
282
+ # Shuffle for each epoch
283
+ while True:
284
+ random.shuffle(dataset)
285
+ for data in dataset:
286
+ yield data
287
+
288
+
289
+
290
+ if __name__ == "__main__":
291
+ parser = argparse.ArgumentParser()
292
+ parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
293
+ parser.add_argument('--steps', type=int, default=2000)
294
+ parser.add_argument('--save_steps', type=int, default=10000)
295
+ parser.add_argument('--batch_size', type=int, default=64)
296
+ parser.add_argument('--max_length_a', type=int, default=128)
297
+ parser.add_argument('--max_length_b', type=int, default=128)
298
+ parser.add_argument('--nprocs', type=int, default=8)
299
+ parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
300
+ parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
301
+ parser.add_argument('--no_normalize', action="store_true", default=False, help="If set: Embeddings are not normalized")
302
+ parser.add_argument('--pooling', default='mean')
303
+ parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
304
+ parser.add_argument('data_config', help="A data_config.json file")
305
+ parser.add_argument('output')
306
+ args = parser.parse_args()
307
+
308
+ # Ensure global batch size is divisble by data_sample_size
309
+ assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
310
+
311
+ logging.info("Output: "+args.output)
312
+ if os.path.exists(args.output):
313
+ print("Output folder already exists.")
314
+ input("Continue?")
315
+
316
+ # Write train script to output path
317
+ os.makedirs(args.output, exist_ok=True)
318
+
319
+ data_config_path = os.path.join(args.output, 'data_config.json')
320
+ copyfile(args.data_config, data_config_path)
321
+
322
+ train_script_path = os.path.join(args.output, 'train_script.py')
323
+ copyfile(__file__, train_script_path)
324
+ with open(train_script_path, 'a') as fOut:
325
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
326
+
327
+
328
+
329
+ #Load data config
330
+ with open(args.data_config) as fIn:
331
+ data_config = json.load(fIn)
332
+
333
+ queue = mp.Queue(maxsize=100*args.nprocs)
334
+
335
+ filepaths = []
336
+ dataset_indices = []
337
+ for idx, data in enumerate(data_config):
338
+ filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
339
+ dataset_indices.extend([idx]*data['weight'])
340
+
341
+ # Start producer
342
+ p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
343
+ p.start()
344
+
345
+ # Run training
346
+ print("Start processes:", args.nprocs)
347
+ xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
348
+ print("Training done")
349
+ print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
350
+ print("With 'pkill python' you can kill all remaining python processes")
351
+ p.kill()
352
+ exit()
353
+
354
+
355
+
356
+ # Script was called via:
357
+ #python train_many_data_files_v2.py --steps 200000 --batch_size 64 --model distilbert-base-uncased --max_length_a 64 --max_length_b 250 --scale 1 --pooling cls --no_normalize train_data_configs/multi-qa_v1.json output/multi-qa_v1-distilbert-base-cls_dot
vocab.txt ADDED
The diff for this file is too large to render. See raw diff