pszemraj committed
Commit 665558b
1 parent: fb68d3b

add additional 2-epoch checkpoint, better regularization

config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "pszemraj/led-large-book-summary-11E",
+  "_name_or_path": "pszemraj/led-large-book-summary",
   "_num_labels": 3,
   "activation_dropout": 0.0,
   "activation_function": "gelu",
@@ -52,8 +52,8 @@
   "length_penalty": 0.8,
   "max_decoder_position_embeddings": 1024,
   "max_encoder_position_embeddings": 16384,
-  "max_length": 512,
-  "min_length": 32,
+  "max_length": 1024,
+  "min_length": 8,
   "model_type": "led",
   "no_repeat_ngram_size": 3,
   "num_beams": 4,
@@ -61,6 +61,7 @@
   "output_past": false,
   "pad_token_id": 1,
   "prefix": " ",
+  "repetition_penalty": 3.5,
   "torch_dtype": "float32",
   "transformers_version": "4.19.2",
   "use_cache": false,
latest ADDED
@@ -0,0 +1 @@
+global_step296
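
The `latest` file is the pointer DeepSpeed writes next to its checkpoint shards: it names the most recent checkpoint tag, here the global_step296 directory. Together with rng_state_0.pth and the bundled zero_to_fp32.py below, it indicates this checkpoint came from a DeepSpeed run; the script's own header gives the consolidation usage, e.g. `python zero_to_fp32.py . pytorch_model.bin`.
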
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cae26c66ffc396d3f1b515359fef3cc2b4a080043dbfa56f7a968450c8ad1b27
+oid sha256:0603a70f15308ebccb9b66369464a83efb920ed266f736bdca0279bc066eecf7
 size 1839482407
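
Both versions of pytorch_model.bin are Git LFS pointer files (spec version, sha256 oid, byte size) rather than raw weights; only the hash changes while the size stays at 1839482407 bytes, consistent with the same architecture re-saved with retrained weights.
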
rng_state_0.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97f7901e7b6ddfc02769b63a371fd87e014fae5a4c2ef46f9e10298a4e62e643
+size 14439
tokenizer_config.json CHANGED
@@ -1 +1 @@
-{"errors": "replace", "bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": "/root/.cache/huggingface/transformers/2ad921573d53ebf0c0450d63a211e61d8e328324e84830c365abff01f2d115f1.cb2244924ab24d706b02fd7fcedaea4531566537687a539ebb94db511fd122a0", "name_or_path": "pszemraj/led-large-book-summary-11E", "tokenizer_class": "LEDTokenizer"}
+{"errors": "replace", "bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": "/root/.cache/huggingface/transformers/2ad921573d53ebf0c0450d63a211e61d8e328324e84830c365abff01f2d115f1.cb2244924ab24d706b02fd7fcedaea4531566537687a539ebb94db511fd122a0", "name_or_path": "pszemraj/led-large-book-summary", "tokenizer_class": "LEDTokenizer"}
trainer_state.json CHANGED
@@ -1,1537 +1,364 @@
1
  {
2
  "best_metric": null,
3
  "best_model_checkpoint": null,
4
- "epoch": 1.9994069974303224,
5
- "global_step": 1264,
6
  "is_hyper_param_search": false,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [
10
  {
11
- "epoch": 0.01,
12
- "learning_rate": 2.631578947368421e-06,
13
- "loss": 1.0727,
14
  "step": 5
15
  },
16
  {
17
- "epoch": 0.02,
18
- "learning_rate": 5.263157894736842e-06,
19
- "loss": 0.8573,
20
  "step": 10
21
  },
22
  {
23
- "epoch": 0.02,
24
- "learning_rate": 7.894736842105265e-06,
25
- "loss": 1.0991,
26
  "step": 15
27
  },
28
  {
29
- "epoch": 0.03,
30
- "learning_rate": 1.0526315789473684e-05,
31
- "loss": 0.8893,
32
  "step": 20
33
  },
34
  {
35
- "epoch": 0.04,
36
- "learning_rate": 1.3157894736842108e-05,
37
- "loss": 0.9549,
38
  "step": 25
39
  },
40
  {
41
- "epoch": 0.05,
42
- "learning_rate": 1.578947368421053e-05,
43
- "loss": 0.9953,
44
  "step": 30
45
  },
46
  {
47
- "epoch": 0.06,
48
- "learning_rate": 1.8421052631578947e-05,
49
- "loss": 1.0163,
50
  "step": 35
51
  },
52
  {
53
- "epoch": 0.06,
54
- "learning_rate": 1.9999868674866755e-05,
55
- "loss": 0.9562,
56
  "step": 40
57
  },
58
  {
59
- "epoch": 0.07,
60
- "learning_rate": 1.9998391306730024e-05,
61
- "loss": 1.0517,
62
  "step": 45
63
  },
64
  {
65
- "epoch": 0.08,
66
- "learning_rate": 1.9995272657365104e-05,
67
- "loss": 0.9458,
68
  "step": 50
69
  },
70
  {
71
- "epoch": 0.09,
72
- "learning_rate": 1.9990513238712407e-05,
73
- "loss": 0.9163,
74
  "step": 55
75
  },
76
  {
77
- "epoch": 0.09,
78
- "learning_rate": 1.998411383205207e-05,
79
- "loss": 0.8653,
80
  "step": 60
81
  },
82
  {
83
- "epoch": 0.1,
84
- "learning_rate": 1.9976075487875692e-05,
85
- "loss": 1.063,
86
  "step": 65
87
  },
88
  {
89
- "epoch": 0.11,
90
- "learning_rate": 1.9966399525713882e-05,
91
- "loss": 0.8901,
92
  "step": 70
93
  },
94
  {
95
- "epoch": 0.12,
96
- "learning_rate": 1.9955087533919662e-05,
97
- "loss": 1.0556,
98
  "step": 75
99
  },
100
  {
101
- "epoch": 0.13,
102
- "learning_rate": 1.994214136940773e-05,
103
- "loss": 0.9434,
104
  "step": 80
105
  },
106
  {
107
- "epoch": 0.13,
108
- "learning_rate": 1.9927563157349646e-05,
109
- "loss": 0.9989,
110
  "step": 85
111
  },
112
  {
113
- "epoch": 0.14,
114
- "learning_rate": 1.9911355290824955e-05,
115
- "loss": 0.9829,
116
  "step": 90
117
  },
118
  {
119
- "epoch": 0.15,
120
- "learning_rate": 1.9893520430428378e-05,
121
- "loss": 0.9988,
122
  "step": 95
123
  },
124
  {
125
- "epoch": 0.16,
126
- "learning_rate": 1.987406150383304e-05,
127
- "loss": 1.0073,
128
  "step": 100
129
  },
130
  {
131
- "epoch": 0.17,
132
- "learning_rate": 1.98529817053099e-05,
133
- "loss": 0.8863,
134
  "step": 105
135
  },
136
  {
137
- "epoch": 0.17,
138
- "learning_rate": 1.983028449520338e-05,
139
- "loss": 0.8015,
140
  "step": 110
141
  },
142
  {
143
- "epoch": 0.18,
144
- "learning_rate": 1.980597359936335e-05,
145
- "loss": 1.0477,
146
  "step": 115
147
  },
148
  {
149
- "epoch": 0.19,
150
- "learning_rate": 1.9780053008533486e-05,
151
- "loss": 0.961,
152
  "step": 120
153
  },
154
  {
155
- "epoch": 0.2,
156
- "learning_rate": 1.97525269776962e-05,
157
- "loss": 0.9476,
158
  "step": 125
159
  },
160
  {
161
- "epoch": 0.21,
162
- "learning_rate": 1.9723400025374168e-05,
163
- "loss": 1.1,
164
  "step": 130
165
  },
166
  {
167
- "epoch": 0.21,
168
- "learning_rate": 1.969267693288855e-05,
169
- "loss": 0.946,
170
  "step": 135
171
  },
172
  {
173
- "epoch": 0.22,
174
- "learning_rate": 1.9660362743574163e-05,
175
- "loss": 0.946,
176
  "step": 140
177
  },
178
  {
179
- "epoch": 0.23,
180
- "learning_rate": 1.9626462761951583e-05,
181
- "loss": 0.9368,
182
  "step": 145
183
  },
184
  {
185
- "epoch": 0.24,
186
- "learning_rate": 1.959098255285636e-05,
187
- "loss": 1.0338,
188
  "step": 150
189
  },
190
  {
191
- "epoch": 0.25,
192
- "learning_rate": 1.9553927940525557e-05,
193
- "loss": 1.0359,
194
  "step": 155
195
  },
196
  {
197
- "epoch": 0.25,
198
- "learning_rate": 1.9515305007641653e-05,
199
- "loss": 1.0189,
200
  "step": 160
201
  },
202
  {
203
- "epoch": 0.26,
204
- "learning_rate": 1.9475120094334046e-05,
205
- "loss": 1.0388,
206
  "step": 165
207
  },
208
  {
209
- "epoch": 0.27,
210
- "learning_rate": 1.9433379797138314e-05,
211
- "loss": 1.0733,
212
  "step": 170
213
  },
214
  {
215
- "epoch": 0.28,
216
- "learning_rate": 1.939009096791333e-05,
217
- "loss": 0.9984,
218
  "step": 175
219
  },
220
  {
221
- "epoch": 0.28,
222
- "learning_rate": 1.9345260712716517e-05,
223
- "loss": 1.0221,
224
  "step": 180
225
  },
226
  {
227
- "epoch": 0.29,
228
- "learning_rate": 1.9298896390637363e-05,
229
- "loss": 1.2205,
230
  "step": 185
231
  },
232
  {
233
- "epoch": 0.3,
234
- "learning_rate": 1.9251005612589382e-05,
235
- "loss": 0.9938,
236
  "step": 190
237
  },
238
  {
239
- "epoch": 0.31,
240
- "learning_rate": 1.9201596240060737e-05,
241
- "loss": 1.0059,
242
  "step": 195
243
  },
244
  {
245
- "epoch": 0.32,
246
- "learning_rate": 1.9150676383823775e-05,
247
- "loss": 1.0072,
248
  "step": 200
249
  },
250
  {
251
- "epoch": 0.32,
252
- "learning_rate": 1.9098254402603573e-05,
253
- "loss": 0.8886,
254
  "step": 205
255
  },
256
  {
257
- "epoch": 0.33,
258
- "learning_rate": 1.904433890170583e-05,
259
- "loss": 1.1496,
260
  "step": 210
261
  },
262
  {
263
- "epoch": 0.34,
264
- "learning_rate": 1.898893873160428e-05,
265
- "loss": 0.8883,
266
  "step": 215
267
  },
268
  {
269
- "epoch": 0.35,
270
- "learning_rate": 1.893206298648781e-05,
271
- "loss": 0.9564,
272
  "step": 220
273
  },
274
  {
275
- "epoch": 0.36,
276
- "learning_rate": 1.887372100276764e-05,
277
- "loss": 0.9264,
278
  "step": 225
279
  },
280
  {
281
- "epoch": 0.36,
282
- "learning_rate": 1.8813922357544713e-05,
283
- "loss": 1.1487,
284
  "step": 230
285
  },
286
  {
287
- "epoch": 0.37,
288
- "learning_rate": 1.875267686703754e-05,
289
- "loss": 1.0543,
290
  "step": 235
291
  },
292
  {
293
- "epoch": 0.38,
294
- "learning_rate": 1.8689994584970843e-05,
295
- "loss": 1.0098,
296
  "step": 240
297
  },
298
  {
299
- "epoch": 0.39,
300
- "learning_rate": 1.8625885800925193e-05,
301
- "loss": 0.915,
302
  "step": 245
303
  },
304
  {
305
- "epoch": 0.4,
306
- "learning_rate": 1.8560361038647917e-05,
307
- "loss": 1.0092,
308
  "step": 250
309
  },
310
  {
311
- "epoch": 0.4,
312
- "learning_rate": 1.8493431054325583e-05,
313
- "loss": 1.0312,
314
  "step": 255
315
  },
316
  {
317
- "epoch": 0.41,
318
- "learning_rate": 1.8425106834818336e-05,
319
- "loss": 1.0963,
320
  "step": 260
321
  },
322
  {
323
- "epoch": 0.42,
324
- "learning_rate": 1.8355399595856326e-05,
325
- "loss": 0.8076,
326
  "step": 265
327
  },
328
  {
329
- "epoch": 0.43,
330
- "learning_rate": 1.8284320780198624e-05,
331
- "loss": 0.8218,
332
  "step": 270
333
  },
334
  {
335
- "epoch": 0.43,
336
- "learning_rate": 1.8211882055754824e-05,
337
- "loss": 0.9424,
338
  "step": 275
339
  },
340
  {
341
- "epoch": 0.44,
342
- "learning_rate": 1.8138095313669705e-05,
343
- "loss": 1.0962,
344
  "step": 280
345
  },
346
  {
347
- "epoch": 0.45,
348
- "learning_rate": 1.8062972666371258e-05,
349
- "loss": 1.0918,
350
  "step": 285
351
  },
352
  {
353
- "epoch": 0.46,
354
- "learning_rate": 1.798652644558236e-05,
355
- "loss": 1.065,
356
  "step": 290
357
- },
358
- {
359
- "epoch": 0.47,
360
- "learning_rate": 1.790876920029647e-05,
361
- "loss": 0.9168,
362
- "step": 295
363
- },
364
- {
365
- "epoch": 0.47,
366
- "learning_rate": 1.7829713694717665e-05,
367
- "loss": 1.0295,
368
- "step": 300
369
- },
370
- {
371
- "epoch": 0.48,
372
- "learning_rate": 1.774937290616533e-05,
373
- "loss": 0.9455,
374
- "step": 305
375
- },
376
- {
377
- "epoch": 0.49,
378
- "learning_rate": 1.7667760022943864e-05,
379
- "loss": 1.0272,
380
- "step": 310
381
- },
382
- {
383
- "epoch": 0.5,
384
- "learning_rate": 1.7584888442177774e-05,
385
- "loss": 1.127,
386
- "step": 315
387
- },
388
- {
389
- "epoch": 0.51,
390
- "learning_rate": 1.7500771767612473e-05,
391
- "loss": 1.0332,
392
- "step": 320
393
- },
394
- {
395
- "epoch": 0.51,
396
- "learning_rate": 1.7415423807381162e-05,
397
- "loss": 1.0603,
398
- "step": 325
399
- },
400
- {
401
- "epoch": 0.52,
402
- "learning_rate": 1.7328858571738157e-05,
403
- "loss": 1.1075,
404
- "step": 330
405
- },
406
- {
407
- "epoch": 0.53,
408
- "learning_rate": 1.7241090270759055e-05,
409
- "loss": 1.1279,
410
- "step": 335
411
- },
412
- {
413
- "epoch": 0.54,
414
- "learning_rate": 1.715213331200807e-05,
415
- "loss": 1.1037,
416
- "step": 340
417
- },
418
- {
419
- "epoch": 0.55,
420
- "learning_rate": 1.7062002298172984e-05,
421
- "loss": 0.9818,
422
- "step": 345
423
- },
424
- {
425
- "epoch": 0.55,
426
- "learning_rate": 1.697071202466803e-05,
427
- "loss": 0.9672,
428
- "step": 350
429
- },
430
- {
431
- "epoch": 0.56,
432
- "learning_rate": 1.687827747720517e-05,
433
- "loss": 0.8006,
434
- "step": 355
435
- },
436
- {
437
- "epoch": 0.57,
438
- "learning_rate": 1.6784713829334124e-05,
439
- "loss": 0.9823,
440
- "step": 360
441
- },
442
- {
443
- "epoch": 0.58,
444
- "learning_rate": 1.6690036439951552e-05,
445
- "loss": 1.0844,
446
- "step": 365
447
- },
448
- {
449
- "epoch": 0.59,
450
- "learning_rate": 1.6594260850779837e-05,
451
- "loss": 1.0858,
452
- "step": 370
453
- },
454
- {
455
- "epoch": 0.59,
456
- "learning_rate": 1.6497402783815834e-05,
457
- "loss": 0.9117,
458
- "step": 375
459
- },
460
- {
461
- "epoch": 0.6,
462
- "learning_rate": 1.6399478138750015e-05,
463
- "loss": 0.9313,
464
- "step": 380
465
- },
466
- {
467
- "epoch": 0.61,
468
- "learning_rate": 1.63005029903565e-05,
469
- "loss": 0.9776,
470
- "step": 385
471
- },
472
- {
473
- "epoch": 0.62,
474
- "learning_rate": 1.620049358585427e-05,
475
- "loss": 1.1445,
476
- "step": 390
477
- },
478
- {
479
- "epoch": 0.62,
480
- "learning_rate": 1.609946634224015e-05,
481
- "loss": 1.0567,
482
- "step": 395
483
- },
484
- {
485
- "epoch": 0.63,
486
- "learning_rate": 1.5997437843593856e-05,
487
- "loss": 0.9138,
488
- "step": 400
489
- },
490
- {
491
- "epoch": 0.64,
492
- "learning_rate": 1.5894424838355654e-05,
493
- "loss": 1.1003,
494
- "step": 405
495
- },
496
- {
497
- "epoch": 0.65,
498
- "learning_rate": 1.5790444236577028e-05,
499
- "loss": 1.0624,
500
- "step": 410
501
- },
502
- {
503
- "epoch": 0.66,
504
- "learning_rate": 1.568551310714482e-05,
505
- "loss": 0.8902,
506
- "step": 415
507
- },
508
- {
509
- "epoch": 0.66,
510
- "learning_rate": 1.557964867497929e-05,
511
- "loss": 1.0329,
512
- "step": 420
513
- },
514
- {
515
- "epoch": 0.67,
516
- "learning_rate": 1.5472868318206566e-05,
517
- "loss": 1.0023,
518
- "step": 425
519
- },
520
- {
521
- "epoch": 0.68,
522
- "learning_rate": 1.5365189565305957e-05,
523
- "loss": 0.9393,
524
- "step": 430
525
- },
526
- {
527
- "epoch": 0.69,
528
- "learning_rate": 1.5256630092232567e-05,
529
- "loss": 0.9747,
530
- "step": 435
531
- },
532
- {
533
- "epoch": 0.7,
534
- "learning_rate": 1.5147207719515692e-05,
535
- "loss": 1.0078,
536
- "step": 440
537
- },
538
- {
539
- "epoch": 0.7,
540
- "learning_rate": 1.5036940409333533e-05,
541
- "loss": 0.9576,
542
- "step": 445
543
- },
544
- {
545
- "epoch": 0.71,
546
- "learning_rate": 1.4925846262564592e-05,
547
- "loss": 0.992,
548
- "step": 450
549
- },
550
- {
551
- "epoch": 0.72,
552
- "learning_rate": 1.4813943515816344e-05,
553
- "loss": 1.0192,
554
- "step": 455
555
- },
556
- {
557
- "epoch": 0.73,
558
- "learning_rate": 1.4701250538431617e-05,
559
- "loss": 1.0154,
560
- "step": 460
561
- },
562
- {
563
- "epoch": 0.74,
564
- "learning_rate": 1.4587785829473173e-05,
565
- "loss": 1.1043,
566
- "step": 465
567
- },
568
- {
569
- "epoch": 0.74,
570
- "learning_rate": 1.4473568014687018e-05,
571
- "loss": 0.8351,
572
- "step": 470
573
- },
574
- {
575
- "epoch": 0.75,
576
- "learning_rate": 1.4358615843444876e-05,
577
- "loss": 1.0276,
578
- "step": 475
579
- },
580
- {
581
- "epoch": 0.76,
582
- "learning_rate": 1.4242948185666419e-05,
583
- "loss": 0.9423,
584
- "step": 480
585
- },
586
- {
587
- "epoch": 0.77,
588
- "learning_rate": 1.4126584028721677e-05,
589
- "loss": 0.9598,
590
- "step": 485
591
- },
592
- {
593
- "epoch": 0.77,
594
- "learning_rate": 1.4009542474314173e-05,
595
- "loss": 1.1755,
596
- "step": 490
597
- },
598
- {
599
- "epoch": 0.78,
600
- "learning_rate": 1.3891842735345285e-05,
601
- "loss": 1.1018,
602
- "step": 495
603
- },
604
- {
605
- "epoch": 0.79,
606
- "learning_rate": 1.3773504132760379e-05,
607
- "loss": 1.0445,
608
- "step": 500
609
- },
610
- {
611
- "epoch": 0.8,
612
- "learning_rate": 1.3654546092377166e-05,
613
- "loss": 0.9674,
614
- "step": 505
615
- },
616
- {
617
- "epoch": 0.81,
618
- "learning_rate": 1.3534988141696891e-05,
619
- "loss": 1.0473,
620
- "step": 510
621
- },
622
- {
623
- "epoch": 0.81,
624
- "learning_rate": 1.3414849906698788e-05,
625
- "loss": 0.9346,
626
- "step": 515
627
- },
628
- {
629
- "epoch": 0.82,
630
- "learning_rate": 1.3294151108618379e-05,
631
- "loss": 0.9643,
632
- "step": 520
633
- },
634
- {
635
- "epoch": 0.83,
636
- "learning_rate": 1.3172911560710167e-05,
637
- "loss": 1.113,
638
- "step": 525
639
- },
640
- {
641
- "epoch": 0.84,
642
- "learning_rate": 1.3051151164995188e-05,
643
- "loss": 1.0155,
644
- "step": 530
645
- },
646
- {
647
- "epoch": 0.85,
648
- "learning_rate": 1.2928889908994003e-05,
649
- "loss": 0.9675,
650
- "step": 535
651
- },
652
- {
653
- "epoch": 0.85,
654
- "learning_rate": 1.280614786244566e-05,
655
- "loss": 0.8424,
656
- "step": 540
657
- },
658
- {
659
- "epoch": 0.86,
660
- "learning_rate": 1.2682945174013148e-05,
661
- "loss": 1.1247,
662
- "step": 545
663
- },
664
- {
665
- "epoch": 0.87,
666
- "learning_rate": 1.2559302067975914e-05,
667
- "loss": 0.9641,
668
- "step": 550
669
- },
670
- {
671
- "epoch": 0.88,
672
- "learning_rate": 1.243523884090995e-05,
673
- "loss": 1.0329,
674
- "step": 555
675
- },
676
- {
677
- "epoch": 0.89,
678
- "learning_rate": 1.2310775858356017e-05,
679
- "loss": 1.0533,
680
- "step": 560
681
- },
682
- {
683
- "epoch": 0.89,
684
- "learning_rate": 1.2185933551476545e-05,
685
- "loss": 1.0216,
686
- "step": 565
687
- },
688
- {
689
- "epoch": 0.9,
690
- "learning_rate": 1.2060732413701773e-05,
691
- "loss": 1.0346,
692
- "step": 570
693
- },
694
- {
695
- "epoch": 0.91,
696
- "learning_rate": 1.1935192997365666e-05,
697
- "loss": 0.9749,
698
- "step": 575
699
- },
700
- {
701
- "epoch": 0.92,
702
- "learning_rate": 1.1809335910332136e-05,
703
- "loss": 1.0901,
704
- "step": 580
705
- },
706
- {
707
- "epoch": 0.93,
708
- "learning_rate": 1.1683181812612186e-05,
709
- "loss": 1.0043,
710
- "step": 585
711
- },
712
- {
713
- "epoch": 0.93,
714
- "learning_rate": 1.1556751412972462e-05,
715
- "loss": 0.9162,
716
- "step": 590
717
- },
718
- {
719
- "epoch": 0.94,
720
- "learning_rate": 1.1430065465535827e-05,
721
- "loss": 0.9228,
722
- "step": 595
723
- },
724
- {
725
- "epoch": 0.95,
726
- "learning_rate": 1.1303144766374476e-05,
727
- "loss": 0.9568,
728
- "step": 600
729
- },
730
- {
731
- "epoch": 0.96,
732
- "learning_rate": 1.1176010150096158e-05,
733
- "loss": 1.063,
734
- "step": 605
735
- },
736
- {
737
- "epoch": 0.96,
738
- "learning_rate": 1.104868248642408e-05,
739
- "loss": 0.9938,
740
- "step": 610
741
- },
742
- {
743
- "epoch": 0.97,
744
- "learning_rate": 1.092118267677106e-05,
745
- "loss": 1.1056,
746
- "step": 615
747
- },
748
- {
749
- "epoch": 0.98,
750
- "learning_rate": 1.0793531650808469e-05,
751
- "loss": 0.9269,
752
- "step": 620
753
- },
754
- {
755
- "epoch": 0.99,
756
- "learning_rate": 1.0665750363030498e-05,
757
- "loss": 1.0452,
758
- "step": 625
759
- },
760
- {
761
- "epoch": 1.0,
762
- "learning_rate": 1.0537859789314424e-05,
763
- "loss": 0.8855,
764
- "step": 630
765
- },
766
- {
767
- "epoch": 1.0,
768
- "learning_rate": 1.0409880923477293e-05,
769
- "loss": 1.0583,
770
- "step": 635
771
- },
772
- {
773
- "epoch": 1.01,
774
- "learning_rate": 1.028183477382971e-05,
775
- "loss": 0.8533,
776
- "step": 640
777
- },
778
- {
779
- "epoch": 1.02,
780
- "learning_rate": 1.0153742359727226e-05,
781
- "loss": 0.942,
782
- "step": 645
783
- },
784
- {
785
- "epoch": 1.03,
786
- "learning_rate": 1.0025624708119901e-05,
787
- "loss": 1.0044,
788
- "step": 650
789
- },
790
- {
791
- "epoch": 1.04,
792
- "learning_rate": 9.897502850100648e-06,
793
- "loss": 0.9652,
794
- "step": 655
795
- },
796
- {
797
- "epoch": 1.04,
798
- "learning_rate": 9.76939781745289e-06,
799
- "loss": 1.0202,
800
- "step": 660
801
- },
802
- {
803
- "epoch": 1.05,
804
- "learning_rate": 9.641330639198083e-06,
805
- "loss": 0.9401,
806
- "step": 665
807
- },
808
- {
809
- "epoch": 1.06,
810
- "learning_rate": 9.513322338143714e-06,
811
- "loss": 0.8343,
812
- "step": 670
813
- },
814
- {
815
- "epoch": 1.07,
816
- "learning_rate": 9.385393927432307e-06,
817
- "loss": 0.8904,
818
- "step": 675
819
- },
820
- {
821
- "epoch": 1.08,
822
- "learning_rate": 9.257566407092032e-06,
823
- "loss": 0.9143,
824
- "step": 680
825
- },
826
- {
827
- "epoch": 1.08,
828
- "learning_rate": 9.129860760589441e-06,
829
- "loss": 0.8408,
830
- "step": 685
831
- },
832
- {
833
- "epoch": 1.09,
834
- "learning_rate": 9.002297951384945e-06,
835
- "loss": 0.9426,
836
- "step": 690
837
- },
838
- {
839
- "epoch": 1.1,
840
- "learning_rate": 8.874898919491564e-06,
841
- "loss": 0.9476,
842
- "step": 695
843
- },
844
- {
845
- "epoch": 1.11,
846
- "learning_rate": 8.74768457803754e-06,
847
- "loss": 0.9992,
848
- "step": 700
849
- },
850
- {
851
- "epoch": 1.12,
852
- "learning_rate": 8.62067580983333e-06,
853
- "loss": 0.8491,
854
- "step": 705
855
- },
856
- {
857
- "epoch": 1.12,
858
- "learning_rate": 8.493893463943617e-06,
859
- "loss": 0.8829,
860
- "step": 710
861
- },
862
- {
863
- "epoch": 1.13,
864
- "learning_rate": 8.367358352264834e-06,
865
- "loss": 0.7933,
866
- "step": 715
867
- },
868
- {
869
- "epoch": 1.14,
870
- "learning_rate": 8.241091246108796e-06,
871
- "loss": 0.9976,
872
- "step": 720
873
- },
874
- {
875
- "epoch": 1.15,
876
- "learning_rate": 8.115112872793006e-06,
877
- "loss": 0.9947,
878
- "step": 725
879
- },
880
- {
881
- "epoch": 1.15,
882
- "learning_rate": 7.989443912238151e-06,
883
- "loss": 0.9171,
884
- "step": 730
885
- },
886
- {
887
- "epoch": 1.16,
888
- "learning_rate": 7.864104993573422e-06,
889
- "loss": 0.9348,
890
- "step": 735
891
- },
892
- {
893
- "epoch": 1.17,
894
- "learning_rate": 7.73911669175013e-06,
895
- "loss": 0.9357,
896
- "step": 740
897
- },
898
- {
899
- "epoch": 1.18,
900
- "learning_rate": 7.614499524164251e-06,
901
- "loss": 0.9273,
902
- "step": 745
903
- },
904
- {
905
- "epoch": 1.19,
906
- "learning_rate": 7.490273947288389e-06,
907
- "loss": 0.8991,
908
- "step": 750
909
- },
910
- {
911
- "epoch": 1.19,
912
- "learning_rate": 7.366460353313762e-06,
913
- "loss": 0.8373,
914
- "step": 755
915
- },
916
- {
917
- "epoch": 1.2,
918
- "learning_rate": 7.2430790668027274e-06,
919
- "loss": 0.8316,
920
- "step": 760
921
- },
922
- {
923
- "epoch": 1.21,
924
- "learning_rate": 7.120150341352413e-06,
925
- "loss": 0.7532,
926
- "step": 765
927
- },
928
- {
929
- "epoch": 1.22,
930
- "learning_rate": 6.99769435627e-06,
931
- "loss": 0.8851,
932
- "step": 770
933
- },
934
- {
935
- "epoch": 1.23,
936
- "learning_rate": 6.875731213260193e-06,
937
- "loss": 0.9678,
938
- "step": 775
939
- },
940
- {
941
- "epoch": 1.23,
942
- "learning_rate": 6.754280933125441e-06,
943
- "loss": 0.9968,
944
- "step": 780
945
- },
946
- {
947
- "epoch": 1.24,
948
- "learning_rate": 6.633363452479431e-06,
949
- "loss": 0.9126,
950
- "step": 785
951
- },
952
- {
953
- "epoch": 1.25,
954
- "learning_rate": 6.512998620474396e-06,
955
- "loss": 0.8765,
956
- "step": 790
957
- },
958
- {
959
- "epoch": 1.26,
960
- "learning_rate": 6.393206195542791e-06,
961
- "loss": 0.8761,
962
- "step": 795
963
- },
964
- {
965
- "epoch": 1.27,
966
- "learning_rate": 6.27400584215386e-06,
967
- "loss": 0.99,
968
- "step": 800
969
- },
970
- {
971
- "epoch": 1.27,
972
- "learning_rate": 6.155417127585617e-06,
973
- "loss": 0.7691,
974
- "step": 805
975
- },
976
- {
977
- "epoch": 1.28,
978
- "learning_rate": 6.037459518712796e-06,
979
- "loss": 0.7794,
980
- "step": 810
981
- },
982
- {
983
- "epoch": 1.29,
984
- "learning_rate": 5.920152378811268e-06,
985
- "loss": 0.9978,
986
- "step": 815
987
- },
988
- {
989
- "epoch": 1.3,
990
- "learning_rate": 5.803514964379482e-06,
991
- "loss": 0.8165,
992
- "step": 820
993
- },
994
- {
995
- "epoch": 1.31,
996
- "learning_rate": 5.68756642197741e-06,
997
- "loss": 0.8918,
998
- "step": 825
999
- },
1000
- {
1001
- "epoch": 1.31,
1002
- "learning_rate": 5.572325785083563e-06,
1003
- "loss": 0.9572,
1004
- "step": 830
1005
- },
1006
- {
1007
- "epoch": 1.32,
1008
- "learning_rate": 5.457811970970564e-06,
1009
- "loss": 0.9112,
1010
- "step": 835
1011
- },
1012
- {
1013
- "epoch": 1.33,
1014
- "learning_rate": 5.3440437775997636e-06,
1015
- "loss": 0.9659,
1016
- "step": 840
1017
- },
1018
- {
1019
- "epoch": 1.34,
1020
- "learning_rate": 5.231039880535511e-06,
1021
- "loss": 0.9516,
1022
- "step": 845
1023
- },
1024
- {
1025
- "epoch": 1.34,
1026
- "learning_rate": 5.118818829879442e-06,
1027
- "loss": 1.076,
1028
- "step": 850
1029
- },
1030
- {
1031
- "epoch": 1.35,
1032
- "learning_rate": 5.0073990472254075e-06,
1033
- "loss": 0.8136,
1034
- "step": 855
1035
- },
1036
- {
1037
- "epoch": 1.36,
1038
- "learning_rate": 4.8967988226354945e-06,
1039
- "loss": 0.9711,
1040
- "step": 860
1041
- },
1042
- {
1043
- "epoch": 1.37,
1044
- "learning_rate": 4.787036311637609e-06,
1045
- "loss": 0.8879,
1046
- "step": 865
1047
- },
1048
- {
1049
- "epoch": 1.38,
1050
- "learning_rate": 4.678129532245189e-06,
1051
- "loss": 1.0423,
1052
- "step": 870
1053
- },
1054
- {
1055
- "epoch": 1.38,
1056
- "learning_rate": 4.570096361999445e-06,
1057
- "loss": 0.892,
1058
- "step": 875
1059
- },
1060
- {
1061
- "epoch": 1.39,
1062
- "learning_rate": 4.462954535034692e-06,
1063
- "loss": 0.9188,
1064
- "step": 880
1065
- },
1066
- {
1067
- "epoch": 1.4,
1068
- "learning_rate": 4.356721639167202e-06,
1069
- "loss": 0.8706,
1070
- "step": 885
1071
- },
1072
- {
1073
- "epoch": 1.41,
1074
- "learning_rate": 4.251415113008096e-06,
1075
- "loss": 0.8926,
1076
- "step": 890
1077
- },
1078
- {
1079
- "epoch": 1.42,
1080
- "learning_rate": 4.147052243100706e-06,
1081
- "loss": 0.9126,
1082
- "step": 895
1083
- },
1084
- {
1085
- "epoch": 1.42,
1086
- "learning_rate": 4.043650161082913e-06,
1087
- "loss": 0.9201,
1088
- "step": 900
1089
- },
1090
- {
1091
- "epoch": 1.43,
1092
- "learning_rate": 3.941225840874925e-06,
1093
- "loss": 0.8758,
1094
- "step": 905
1095
- },
1096
- {
1097
- "epoch": 1.44,
1098
- "learning_rate": 3.839796095892905e-06,
1099
- "loss": 0.8818,
1100
- "step": 910
1101
- },
1102
- {
1103
- "epoch": 1.45,
1104
- "learning_rate": 3.7393775762889963e-06,
1105
- "loss": 0.9018,
1106
- "step": 915
1107
- },
1108
- {
1109
- "epoch": 1.46,
1110
- "learning_rate": 3.639986766218112e-06,
1111
- "loss": 0.8346,
1112
- "step": 920
1113
- },
1114
- {
1115
- "epoch": 1.46,
1116
- "learning_rate": 3.541639981131996e-06,
1117
- "loss": 0.9619,
1118
- "step": 925
1119
- },
1120
- {
1121
- "epoch": 1.47,
1122
- "learning_rate": 3.4443533651009474e-06,
1123
- "loss": 1.0202,
1124
- "step": 930
1125
- },
1126
- {
1127
- "epoch": 1.48,
1128
- "learning_rate": 3.348142888163726e-06,
1129
- "loss": 0.9186,
1130
- "step": 935
1131
- },
1132
- {
1133
- "epoch": 1.49,
1134
- "learning_rate": 3.2530243437059773e-06,
1135
- "loss": 0.9326,
1136
- "step": 940
1137
- },
1138
- {
1139
- "epoch": 1.49,
1140
- "learning_rate": 3.1590133458676787e-06,
1141
- "loss": 0.8347,
1142
- "step": 945
1143
- },
1144
- {
1145
- "epoch": 1.5,
1146
- "learning_rate": 3.066125326980027e-06,
1147
- "loss": 0.8914,
1148
- "step": 950
1149
- },
1150
- {
1151
- "epoch": 1.51,
1152
- "learning_rate": 2.9743755350321213e-06,
1153
- "loss": 0.8869,
1154
- "step": 955
1155
- },
1156
- {
1157
- "epoch": 1.52,
1158
- "learning_rate": 2.8837790311679625e-06,
1159
- "loss": 0.9833,
1160
- "step": 960
1161
- },
1162
- {
1163
- "epoch": 1.53,
1164
- "learning_rate": 2.7943506872140844e-06,
1165
- "loss": 0.8893,
1166
- "step": 965
1167
- },
1168
- {
1169
- "epoch": 1.53,
1170
- "learning_rate": 2.7061051832382836e-06,
1171
- "loss": 0.8616,
1172
- "step": 970
1173
- },
1174
- {
1175
- "epoch": 1.54,
1176
- "learning_rate": 2.6190570051398035e-06,
1177
- "loss": 0.9909,
1178
- "step": 975
1179
- },
1180
- {
1181
- "epoch": 1.55,
1182
- "learning_rate": 2.5332204422714368e-06,
1183
- "loss": 0.9515,
1184
- "step": 980
1185
- },
1186
- {
1187
- "epoch": 1.56,
1188
- "learning_rate": 2.4486095850938352e-06,
1189
- "loss": 1.0862,
1190
- "step": 985
1191
- },
1192
- {
1193
- "epoch": 1.57,
1194
- "learning_rate": 2.365238322862511e-06,
1195
- "loss": 0.9123,
1196
- "step": 990
1197
- },
1198
- {
1199
- "epoch": 1.57,
1200
- "learning_rate": 2.2831203413478555e-06,
1201
- "loss": 0.8965,
1202
- "step": 995
1203
- },
1204
- {
1205
- "epoch": 1.58,
1206
- "learning_rate": 2.202269120588546e-06,
1207
- "loss": 0.9261,
1208
- "step": 1000
1209
- },
1210
- {
1211
- "epoch": 1.59,
1212
- "learning_rate": 2.122697932678748e-06,
1213
- "loss": 0.9044,
1214
- "step": 1005
1215
- },
1216
- {
1217
- "epoch": 1.6,
1218
- "learning_rate": 2.0444198395894332e-06,
1219
- "loss": 0.923,
1220
- "step": 1010
1221
- },
1222
- {
1223
- "epoch": 1.61,
1224
- "learning_rate": 1.9674476910242055e-06,
1225
- "loss": 0.9581,
1226
- "step": 1015
1227
- },
1228
- {
1229
- "epoch": 1.61,
1230
- "learning_rate": 1.891794122309949e-06,
1231
- "loss": 1.0578,
1232
- "step": 1020
1233
- },
1234
- {
1235
- "epoch": 1.62,
1236
- "learning_rate": 1.8174715523227017e-06,
1237
- "loss": 0.886,
1238
- "step": 1025
1239
- },
1240
- {
1241
- "epoch": 1.63,
1242
- "learning_rate": 1.7444921814490256e-06,
1243
- "loss": 0.8365,
1244
- "step": 1030
1245
- },
1246
- {
1247
- "epoch": 1.64,
1248
- "learning_rate": 1.6728679895832622e-06,
1249
- "loss": 0.9913,
1250
- "step": 1035
1251
- },
1252
- {
1253
- "epoch": 1.65,
1254
- "learning_rate": 1.6026107341609842e-06,
1255
- "loss": 0.8921,
1256
- "step": 1040
1257
- },
1258
- {
1259
- "epoch": 1.65,
1260
- "learning_rate": 1.5337319482289503e-06,
1261
- "loss": 0.8323,
1262
- "step": 1045
1263
- },
1264
- {
1265
- "epoch": 1.66,
1266
- "learning_rate": 1.4662429385519084e-06,
1267
- "loss": 0.8869,
1268
- "step": 1050
1269
- },
1270
- {
1271
- "epoch": 1.67,
1272
- "learning_rate": 1.400154783756541e-06,
1273
- "loss": 0.9799,
1274
- "step": 1055
1275
- },
1276
- {
1277
- "epoch": 1.68,
1278
- "learning_rate": 1.3354783325128561e-06,
1279
- "loss": 0.9416,
1280
- "step": 1060
1281
- },
1282
- {
1283
- "epoch": 1.68,
1284
- "learning_rate": 1.2722242017533192e-06,
1285
- "loss": 0.9801,
1286
- "step": 1065
1287
- },
1288
- {
1289
- "epoch": 1.69,
1290
- "learning_rate": 1.2104027749300574e-06,
1291
- "loss": 0.9614,
1292
- "step": 1070
1293
- },
1294
- {
1295
- "epoch": 1.7,
1296
- "learning_rate": 1.150024200310348e-06,
1297
- "loss": 1.0792,
1298
- "step": 1075
1299
- },
1300
- {
1301
- "epoch": 1.71,
1302
- "learning_rate": 1.0910983893107419e-06,
1303
- "loss": 0.8622,
1304
- "step": 1080
1305
- },
1306
- {
1307
- "epoch": 1.72,
1308
- "learning_rate": 1.0336350148700668e-06,
1309
- "loss": 0.8627,
1310
- "step": 1085
1311
- },
1312
- {
1313
- "epoch": 1.72,
1314
- "learning_rate": 9.776435098615578e-07,
1315
- "loss": 0.857,
1316
- "step": 1090
1317
- },
1318
- {
1319
- "epoch": 1.73,
1320
- "learning_rate": 9.231330655444193e-07,
1321
- "loss": 1.0396,
1322
- "step": 1095
1323
- },
1324
- {
1325
- "epoch": 1.74,
1326
- "learning_rate": 8.701126300550322e-07,
1327
- "loss": 1.0174,
1328
- "step": 1100
1329
- },
1330
- {
1331
- "epoch": 1.75,
1332
- "learning_rate": 8.185909069380782e-07,
1333
- "loss": 1.031,
1334
- "step": 1105
1335
- },
1336
- {
1337
- "epoch": 1.76,
1338
- "learning_rate": 7.685763537178093e-07,
1339
- "loss": 0.9254,
1340
- "step": 1110
1341
- },
1342
- {
1343
- "epoch": 1.76,
1344
- "learning_rate": 7.200771805097206e-07,
1345
- "loss": 0.9294,
1346
- "step": 1115
1347
- },
1348
- {
1349
- "epoch": 1.77,
1350
- "learning_rate": 6.731013486728044e-07,
1351
- "loss": 0.8641,
1352
- "step": 1120
1353
- },
1354
- {
1355
- "epoch": 1.78,
1356
- "learning_rate": 6.276565695026671e-07,
1357
- "loss": 0.9669,
1358
- "step": 1125
1359
- },
1360
- {
1361
- "epoch": 1.79,
1362
- "learning_rate": 5.837503029656888e-07,
1363
- "loss": 0.8447,
1364
- "step": 1130
1365
- },
1366
- {
1367
- "epoch": 1.8,
1368
- "learning_rate": 5.413897564744253e-07,
1369
- "loss": 0.9566,
1370
- "step": 1135
1371
- },
1372
- {
1373
- "epoch": 1.8,
1374
- "learning_rate": 5.005818837044885e-07,
1375
- "loss": 0.9457,
1376
- "step": 1140
1377
- },
1378
- {
1379
- "epoch": 1.81,
1380
- "learning_rate": 4.613333834530631e-07,
1381
- "loss": 0.8523,
1382
- "step": 1145
1383
- },
1384
- {
1385
- "epoch": 1.82,
1386
- "learning_rate": 4.23650698539273e-07,
1387
- "loss": 0.9696,
1388
- "step": 1150
1389
- },
1390
- {
1391
- "epoch": 1.83,
1392
- "learning_rate": 3.8754001474655354e-07,
1393
- "loss": 0.7867,
1394
- "step": 1155
1395
- },
1396
- {
1397
- "epoch": 1.83,
1398
- "learning_rate": 3.530072598072454e-07,
1399
- "loss": 0.9534,
1400
- "step": 1160
1401
- },
1402
- {
1403
- "epoch": 1.84,
1404
- "learning_rate": 3.200581024295102e-07,
1405
- "loss": 0.9209,
1406
- "step": 1165
1407
- },
1408
- {
1409
- "epoch": 1.85,
1410
- "learning_rate": 2.886979513667998e-07,
1411
- "loss": 0.8257,
1412
- "step": 1170
1413
- },
1414
- {
1415
- "epoch": 1.86,
1416
- "learning_rate": 2.589319545299807e-07,
1417
- "loss": 0.9424,
1418
- "step": 1175
1419
- },
1420
- {
1421
- "epoch": 1.87,
1422
- "learning_rate": 2.3076499814227992e-07,
1423
- "loss": 0.873,
1424
- "step": 1180
1425
- },
1426
- {
1427
- "epoch": 1.87,
1428
- "learning_rate": 2.042017059371948e-07,
1429
- "loss": 0.843,
1430
- "step": 1185
1431
- },
1432
- {
1433
- "epoch": 1.88,
1434
- "learning_rate": 1.7924643839947632e-07,
1435
- "loss": 0.9136,
1436
- "step": 1190
1437
- },
1438
- {
1439
- "epoch": 1.89,
1440
- "learning_rate": 1.559032920493464e-07,
1441
- "loss": 1.0255,
1442
- "step": 1195
1443
- },
1444
- {
1445
- "epoch": 1.9,
1446
- "learning_rate": 1.3417609877002691e-07,
1447
- "loss": 0.8384,
1448
- "step": 1200
1449
- },
1450
- {
1451
- "epoch": 1.91,
1452
- "learning_rate": 1.1406842517872608e-07,
1453
- "loss": 0.7154,
1454
- "step": 1205
1455
- },
1456
- {
1457
- "epoch": 1.91,
1458
- "learning_rate": 9.558357204115464e-08,
1459
- "loss": 0.8805,
1460
- "step": 1210
1461
- },
1462
- {
1463
- "epoch": 1.92,
1464
- "learning_rate": 7.872457372969711e-08,
1465
- "loss": 0.7834,
1466
- "step": 1215
1467
- },
1468
- {
1469
- "epoch": 1.93,
1470
- "learning_rate": 6.34941977253023e-08,
1471
- "loss": 0.8752,
1472
- "step": 1220
1473
- },
1474
- {
1475
- "epoch": 1.94,
1476
- "learning_rate": 4.989494416318685e-08,
1477
- "loss": 0.9148,
1478
- "step": 1225
1479
- },
1480
- {
1481
- "epoch": 1.95,
1482
- "learning_rate": 3.7929045422432364e-08,
1483
- "loss": 0.8612,
1484
- "step": 1230
1485
- },
1486
- {
1487
- "epoch": 1.95,
1488
- "learning_rate": 2.7598465759526294e-08,
1489
- "loss": 0.9299,
1490
- "step": 1235
1491
- },
1492
- {
1493
- "epoch": 1.96,
1494
- "learning_rate": 1.8904900985918796e-08,
1495
- "loss": 0.9918,
1496
- "step": 1240
1497
- },
1498
- {
1499
- "epoch": 1.97,
1500
- "learning_rate": 1.184977818965205e-08,
1501
- "loss": 0.9008,
1502
- "step": 1245
1503
- },
1504
- {
1505
- "epoch": 1.98,
1506
- "learning_rate": 6.434255501095443e-09,
1507
- "loss": 0.9461,
1508
- "step": 1250
1509
- },
1510
- {
1511
- "epoch": 1.99,
1512
- "learning_rate": 2.659221902830966e-09,
1513
- "loss": 0.8181,
1514
- "step": 1255
1515
- },
1516
- {
1517
- "epoch": 1.99,
1518
- "learning_rate": 5.252970837255067e-10,
1519
- "loss": 0.8643,
1520
- "step": 1260
1521
- },
1522
- {
1523
- "epoch": 2.0,
1524
- "step": 1264,
1525
- "total_flos": 3.882961297263821e+17,
1526
- "train_loss": 0.9575079066466682,
1527
- "train_runtime": 77706.2136,
1528
- "train_samples_per_second": 0.26,
1529
- "train_steps_per_second": 0.016
1530
  }
1531
  ],
1532
- "max_steps": 1264,
1533
- "num_train_epochs": 2,
1534
- "total_flos": 3.882961297263821e+17,
1535
  "trial_name": null,
1536
  "trial_params": null
1537
  }
1
  {
2
  "best_metric": null,
3
  "best_model_checkpoint": null,
4
+ "epoch": 2.9914638001896936,
5
+ "global_step": 294,
6
  "is_hyper_param_search": false,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [
10
  {
11
+ "epoch": 0.05,
12
+ "learning_rate": 7.5e-06,
13
+ "loss": 0.481,
14
  "step": 5
15
  },
16
  {
17
+ "epoch": 0.1,
18
+ "learning_rate": 1.5e-05,
19
+ "loss": 0.5091,
20
  "step": 10
21
  },
22
  {
23
+ "epoch": 0.15,
24
+ "learning_rate": 2.25e-05,
25
+ "loss": 0.4303,
26
  "step": 15
27
  },
28
  {
29
+ "epoch": 0.2,
30
+ "learning_rate": 3e-05,
31
+ "loss": 0.4055,
32
  "step": 20
33
  },
34
  {
35
+ "epoch": 0.25,
36
+ "learning_rate": 2.998662940889891e-05,
37
+ "loss": 0.4338,
38
  "step": 25
39
  },
40
  {
41
+ "epoch": 0.3,
42
+ "learning_rate": 2.9946541471956496e-05,
43
+ "loss": 0.4793,
44
  "step": 30
45
  },
46
  {
47
+ "epoch": 0.35,
48
+ "learning_rate": 2.9879807655761145e-05,
49
+ "loss": 0.4291,
50
  "step": 35
51
  },
52
  {
53
+ "epoch": 0.4,
54
+ "learning_rate": 2.9786546929722055e-05,
55
+ "loss": 0.4908,
56
  "step": 40
57
  },
58
  {
59
+ "epoch": 0.46,
60
+ "learning_rate": 2.966692555397705e-05,
61
+ "loss": 0.463,
62
  "step": 45
63
  },
64
  {
65
+ "epoch": 0.51,
66
+ "learning_rate": 2.9521156782993066e-05,
67
+ "loss": 0.528,
68
  "step": 50
69
  },
70
  {
71
+ "epoch": 0.56,
72
+ "learning_rate": 2.9349500485387718e-05,
73
+ "loss": 0.5178,
74
  "step": 55
75
  },
76
  {
77
+ "epoch": 0.61,
78
+ "learning_rate": 2.9152262680649704e-05,
79
+ "loss": 0.4602,
80
  "step": 60
81
  },
82
  {
83
+ "epoch": 0.66,
84
+ "learning_rate": 2.8929794993583937e-05,
85
+ "loss": 0.5044,
86
  "step": 65
87
  },
88
  {
89
+ "epoch": 0.71,
90
+ "learning_rate": 2.8682494027454e-05,
91
+ "loss": 0.4217,
92
  "step": 70
93
  },
94
  {
95
+ "epoch": 0.76,
96
+ "learning_rate": 2.8410800656939512e-05,
97
+ "loss": 0.502,
98
  "step": 75
99
  },
100
  {
101
+ "epoch": 0.81,
102
+ "learning_rate": 2.811519924216873e-05,
103
+ "loss": 0.4549,
104
  "step": 80
105
  },
106
  {
107
+ "epoch": 0.86,
108
+ "learning_rate": 2.779621676522777e-05,
109
+ "loss": 0.5692,
110
  "step": 85
111
  },
112
  {
113
+ "epoch": 0.91,
114
+ "learning_rate": 2.7454421890685647e-05,
115
+ "loss": 0.4312,
116
  "step": 90
117
  },
118
  {
119
+ "epoch": 0.96,
120
+ "learning_rate": 2.709042395181008e-05,
121
+ "loss": 0.4938,
122
  "step": 95
123
  },
124
  {
125
+ "epoch": 1.02,
126
+ "learning_rate": 2.6704871864281377e-05,
127
+ "loss": 0.5433,
128
  "step": 100
129
  },
130
  {
131
+ "epoch": 1.07,
132
+ "learning_rate": 2.6298452969340952e-05,
133
+ "loss": 0.3459,
134
  "step": 105
135
  },
136
  {
137
+ "epoch": 1.12,
138
+ "learning_rate": 2.58718918084368e-05,
139
+ "loss": 0.3739,
140
  "step": 110
141
  },
142
  {
143
+ "epoch": 1.17,
144
+ "learning_rate": 2.5425948831550528e-05,
145
+ "loss": 0.375,
146
  "step": 115
147
  },
148
  {
149
+ "epoch": 1.22,
150
+ "learning_rate": 2.496141904150859e-05,
151
+ "loss": 0.3809,
152
  "step": 120
153
  },
154
  {
155
+ "epoch": 1.27,
156
+ "learning_rate": 2.447913057669456e-05,
157
+ "loss": 0.4183,
158
  "step": 125
159
  },
160
  {
161
+ "epoch": 1.32,
162
+ "learning_rate": 2.3979943234689226e-05,
163
+ "loss": 0.4207,
164
  "step": 130
165
  },
166
  {
167
+ "epoch": 1.37,
168
+ "learning_rate": 2.3464746939470288e-05,
169
+ "loss": 0.3767,
170
  "step": 135
171
  },
172
  {
173
+ "epoch": 1.42,
174
+ "learning_rate": 2.2934460154904436e-05,
175
+ "loss": 0.4248,
176
  "step": 140
177
  },
178
  {
179
+ "epoch": 1.48,
180
+ "learning_rate": 2.2390028247360042e-05,
181
+ "loss": 0.3374,
182
  "step": 145
183
  },
184
  {
185
+ "epoch": 1.53,
186
+ "learning_rate": 2.183242180035951e-05,
187
+ "loss": 0.4582,
188
  "step": 150
189
  },
190
  {
191
+ "epoch": 1.58,
192
+ "learning_rate": 2.1262634884275948e-05,
193
+ "loss": 0.4153,
194
  "step": 155
195
  },
196
  {
197
+ "epoch": 1.63,
198
+ "learning_rate": 2.068168328415864e-05,
199
+ "loss": 0.409,
200
  "step": 160
201
  },
202
  {
203
+ "epoch": 1.68,
204
+ "learning_rate": 2.0090602688846884e-05,
205
+ "loss": 0.4023,
206
  "step": 165
207
  },
208
  {
209
+ "epoch": 1.73,
210
+ "learning_rate": 1.9490446844600375e-05,
211
+ "loss": 0.3426,
212
  "step": 170
213
  },
214
  {
215
+ "epoch": 1.78,
216
+ "learning_rate": 1.888228567653781e-05,
217
+ "loss": 0.4059,
218
  "step": 175
219
  },
220
  {
221
+ "epoch": 1.83,
222
+ "learning_rate": 1.8267203381232774e-05,
223
+ "loss": 0.4449,
224
  "step": 180
225
  },
226
  {
227
+ "epoch": 1.88,
228
+ "learning_rate": 1.764629649386713e-05,
229
+ "loss": 0.4362,
230
  "step": 185
231
  },
232
  {
233
+ "epoch": 1.93,
234
+ "learning_rate": 1.7020671933387917e-05,
235
+ "loss": 0.4874,
236
  "step": 190
237
  },
238
  {
239
+ "epoch": 1.98,
240
+ "learning_rate": 1.63914450291526e-05,
241
+ "loss": 0.3326,
242
  "step": 195
243
  },
244
  {
245
+ "epoch": 2.04,
246
+ "learning_rate": 1.5633197410233404e-05,
247
+ "loss": 0.4035,
248
  "step": 200
249
  },
250
  {
251
+ "epoch": 2.09,
252
+ "learning_rate": 1.5e-05,
253
+ "loss": 0.3291,
254
  "step": 205
255
  },
256
  {
257
+ "epoch": 2.14,
258
+ "learning_rate": 1.4366802589766598e-05,
259
+ "loss": 0.353,
260
  "step": 210
261
  },
262
  {
263
+ "epoch": 2.19,
264
+ "learning_rate": 1.373473400935433e-05,
265
+ "loss": 0.319,
266
  "step": 215
267
  },
268
  {
269
+ "epoch": 2.24,
270
+ "learning_rate": 1.3104921076168065e-05,
271
+ "loss": 0.341,
272
  "step": 220
273
  },
274
  {
275
+ "epoch": 2.29,
276
+ "learning_rate": 1.247848658636778e-05,
277
+ "loss": 0.3276,
278
  "step": 225
279
  },
280
  {
281
+ "epoch": 2.34,
282
+ "learning_rate": 1.185654731320877e-05,
283
+ "loss": 0.3628,
284
  "step": 230
285
  },
286
  {
287
+ "epoch": 2.39,
288
+ "learning_rate": 1.124021201611919e-05,
289
+ "loss": 0.2727,
290
  "step": 235
291
  },
292
  {
293
+ "epoch": 2.45,
294
+ "learning_rate": 1.0630579464064182e-05,
295
+ "loss": 0.3466,
296
  "step": 240
297
  },
298
  {
299
+ "epoch": 2.5,
300
+ "learning_rate": 1.0028736476720464e-05,
301
+ "loss": 0.3187,
302
  "step": 245
303
  },
304
  {
305
+ "epoch": 2.55,
306
+ "learning_rate": 9.435755986953485e-06,
307
+ "loss": 0.3837,
308
  "step": 250
309
  },
310
  {
311
+ "epoch": 2.6,
312
+ "learning_rate": 8.852695128051192e-06,
313
+ "loss": 0.2955,
314
  "step": 255
315
  },
316
  {
317
+ "epoch": 2.65,
318
+ "learning_rate": 8.280593349124432e-06,
319
+ "loss": 0.3793,
320
  "step": 260
321
  },
322
  {
323
+ "epoch": 2.7,
324
+ "learning_rate": 7.720470562033787e-06,
325
+ "loss": 0.3443,
326
  "step": 265
327
  },
328
  {
329
+ "epoch": 2.75,
330
+ "learning_rate": 7.17332532314626e-06,
331
+ "loss": 0.2915,
332
  "step": 270
333
  },
334
  {
335
+ "epoch": 2.8,
336
+ "learning_rate": 6.640133053163455e-06,
337
+ "loss": 0.3514,
338
  "step": 275
339
  },
340
  {
341
+ "epoch": 2.85,
342
+ "learning_rate": 6.12184429819474e-06,
343
+ "loss": 0.3221,
344
  "step": 280
345
  },
346
  {
347
+ "epoch": 2.9,
348
+ "learning_rate": 5.619383035175448e-06,
349
+ "loss": 0.2903,
350
  "step": 285
351
  },
352
  {
353
+ "epoch": 2.95,
354
+ "learning_rate": 5.133645024651171e-06,
355
+ "loss": 0.3397,
356
  "step": 290
357
  }
358
  ],
359
+ "max_steps": 392,
360
+ "num_train_epochs": 4,
361
+ "total_flos": 1.821325738775675e+17,
362
  "trial_name": null,
363
  "trial_params": null
364
  }
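
Compared with the removed history, the new run logs 294 of a planned 392 steps (num_train_epochs 4) with a higher 3e-05 peak learning rate, and the logged loss drops from roughly 0.8-1.1 to roughly 0.3-0.5. A short sketch for inspecting the new file, assuming only the log_history layout visible above:

# Sketch: print the logged loss curve from trainer_state.json.
# Assumes the layout shown in the diff: a "log_history" list of
# {"epoch", "learning_rate", "loss", "step"} entries.
import json

with open("trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    if "loss" in entry:  # skip any summary entries without a logged loss
        print(f'step {entry["step"]:>4}  epoch {entry["epoch"]:.2f}  '
              f'lr {entry["learning_rate"]:.2e}  loss {entry["loss"]:.4f}')
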
training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c2c8d84c6ea850e743bbd2858bd2ac8b6913bd2628de22bb48d811f906d778c2
+oid sha256:cb4937467d6816f4a87c65a795706d60d5c1945956041be1f6697c8ed7d29b1c
 size 4399
zero_to_fp32.py ADDED
@@ -0,0 +1,484 @@
1
+ #!/usr/bin/env python
2
+
3
+ # This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
4
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
5
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
6
+ # application.
7
+ #
8
+ # example: python zero_to_fp32.py . pytorch_model.bin
9
+
10
+ import argparse
11
+ import torch
12
+ import glob
13
+ import math
14
+ import os
15
+ import re
16
+ from collections import OrderedDict
17
+
18
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
19
+ # DeepSpeed data structures it has to be available in the current python environment.
20
+ import deepspeed
21
+ from deepspeed.utils import logger
22
+ from deepspeed.checkpoint.constants import (DS_VERSION,
23
+ OPTIMIZER_STATE_DICT,
24
+ PARAM_SHAPES,
25
+ SINGLE_PARTITION_OF_FP32_GROUPS,
26
+ FP32_FLAT_GROUPS,
27
+ ZERO_STAGE,
28
+ PARTITION_COUNT,
29
+ PARAM_SHAPES,
30
+ BUFFER_NAMES)
31
+
32
+ debug = 0
33
+
34
+ # load to cpu
35
+ device = torch.device('cpu')
36
+
37
+
38
+ def atoi(text):
39
+ return int(text) if text.isdigit() else text
40
+
41
+
42
+ def natural_keys(text):
43
+ '''
44
+ alist.sort(key=natural_keys) sorts in human order
45
+ http://nedbatchelder.com/blog/200712/human_sorting.html
46
+ (See Toothy's implementation in the comments)
47
+ '''
48
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
49
+
50
+
51
+ def get_model_state_file(checkpoint_dir, zero_stage):
52
+ if not os.path.isdir(checkpoint_dir):
53
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
54
+
55
+ # there should be only one file
56
+ if zero_stage == 2:
57
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
58
+ elif zero_stage == 3:
59
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
60
+
61
+ if not os.path.exists(file):
62
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
63
+
64
+ return file
65
+
66
+
67
+ def get_optim_files(checkpoint_dir):
68
+ # XXX: need to test that this simple glob rule works for multi-node setup too
69
+ optim_files = sorted(glob.glob(os.path.join(checkpoint_dir,
70
+ "*_optim_states.pt")),
71
+ key=natural_keys)
72
+
73
+ if len(optim_files) == 0:
74
+ raise FileNotFoundError(
75
+ f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
76
+
77
+ return optim_files
78
+
79
+
80
+ def parse_model_state(file):
81
+ state_dict = torch.load(file, map_location=device)
82
+
83
+ if BUFFER_NAMES not in state_dict:
84
+ raise ValueError(f"{file} is not a model state checkpoint")
85
+ buffer_names = state_dict[BUFFER_NAMES]
86
+ if debug:
87
+ print("Found buffers:", buffer_names)
88
+
89
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
90
+ buffers = {
91
+ k: v.float()
92
+ for k,
93
+ v in state_dict["module"].items() if k in buffer_names
94
+ }
95
+ param_shapes = state_dict[PARAM_SHAPES]
96
+
97
+ ds_version = state_dict.get(DS_VERSION, None)
98
+
99
+ return buffers, param_shapes, ds_version
100
+
101
+
102
+ def parse_optim_states(files, ds_checkpoint_dir):
103
+
104
+ total_files = len(files)
105
+ state_dicts = []
106
+ for f in files:
107
+ state_dicts.append(torch.load(f, map_location=device))
108
+
109
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
110
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
111
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
112
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
113
+
114
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
115
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
116
+ # use the max of the partition_count to get the dp world_size.
117
+
118
+ if type(world_size) is list:
119
+ world_size = max(world_size)
120
+
121
+ if world_size != total_files:
122
+ raise ValueError(
123
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
124
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
125
+ )
126
+
127
+ # the groups are named differently in each stage
128
+ if zero_stage == 2:
129
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
130
+ elif zero_stage == 3:
131
+ fp32_groups_key = FP32_FLAT_GROUPS
132
+ else:
133
+ raise ValueError(f"unknown zero stage {zero_stage}")
134
+
135
+ if zero_stage == 2:
136
+ fp32_flat_groups = [
137
+ state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key]
138
+ for i in range(len(state_dicts))
139
+ ]
140
+ elif zero_stage == 3:
141
+ # if there is more than one param group, there will be multiple flattened tensors - one
142
+ # flattened tensor per group - for simplicity merge them into a single tensor
143
+ #
144
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
145
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
146
+
147
+ fp32_flat_groups = [
148
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key],
149
+ 0) for i in range(len(state_dicts))
150
+ ]
151
+
152
+ return zero_stage, world_size, fp32_flat_groups
153
+
154
+
155
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
156
+ """
157
+ Returns fp32 state_dict reconstructed from ds checkpoint
158
+
159
+ Args:
160
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
161
+
162
+ """
163
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
164
+
165
+ optim_files = get_optim_files(ds_checkpoint_dir)
166
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
167
+ print(
168
+ f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
169
+
170
+ model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
171
+ buffers, param_shapes, ds_version = parse_model_state(model_file)
172
+ print(f'Parsing checkpoint created by deepspeed=={ds_version}')
173
+
174
+ if zero_stage == 2:
175
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
176
+ param_shapes,
177
+ fp32_flat_groups,
178
+ buffers)
179
+ elif zero_stage == 3:
180
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
181
+ param_shapes,
182
+ fp32_flat_groups,
183
+ buffers)
184
+
185
+
186
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
187
+ param_shapes,
188
+ fp32_flat_groups,
189
+ buffers):
190
+
191
+ # Reconstruction protocol:
192
+ #
193
+ # XXX: document this
194
+
+    if debug:
+        for i in range(world_size):
+            for j in range(len(fp32_flat_groups[0])):
+                print(
+                    f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+    # XXX: memory usage doubles here (zero2)
+    num_param_groups = len(fp32_flat_groups[0])
+    merged_single_partition_of_fp32_groups = []
+    for i in range(num_param_groups):
+        merged_partitions = [sd[i] for sd in fp32_flat_groups]
+        full_single_fp32_vector = torch.cat(merged_partitions, 0)
+        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+    avail_numel = sum([
+        full_single_fp32_vector.numel()
+        for full_single_fp32_vector in merged_single_partition_of_fp32_groups
+    ])
+
+    if debug:
+        wanted_params = sum([len(shapes) for shapes in param_shapes])
+        wanted_numel = sum(
+            [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+        # not asserting if there is a mismatch due to possible padding
+        print(f"Have {avail_numel} numels to process.")
+        print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+    state_dict = OrderedDict()
+
+    # buffers
+    state_dict.update(buffers)
+    if debug:
+        print(f"added {len(buffers)} buffers")
+
+    # params
+    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+    # out-of-core computing solution
+    total_numel = 0
+    total_params = 0
+    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+        offset = 0
+        avail_numel = full_single_fp32_vector.numel()
+        for name, shape in shapes.items():
+
+            unpartitioned_numel = shape.numel()
+            total_numel += unpartitioned_numel
+            total_params += 1
+
+            if debug:
+                print(
+                    f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
+                )
+            state_dict[name] = full_single_fp32_vector.narrow(
+                0,
+                offset,
+                unpartitioned_numel).view(shape)
+            offset += unpartitioned_numel
+
+        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+        # live optimizer object, so we are checking that the numbers are within the right range
+        align_to = 2 * world_size
+
+        def zero2_align(x):
+            return align_to * math.ceil(x / align_to)
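+        # e.g. with world_size=2: align_to=4, so zero2_align(10) -> 12 and
+        # zero2_align(12) -> 12 (values are rounded up to a multiple of 4)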
+
+        if debug:
+            print(f"original offset={offset}, avail_numel={avail_numel}")
+
+        offset = zero2_align(offset)
+        avail_numel = zero2_align(avail_numel)
+
+        if debug:
+            print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+        # Sanity check
+        if offset != avail_numel:
+            raise ValueError(
+                f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+    print(
+        f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
+    )
+
+    return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+    remainder = unpartitioned_numel % world_size
+    padding_numel = (world_size - remainder) if remainder else 0
+    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
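+    # example: unpartitioned_numel=10 on world_size=4 -> remainder=2, so
+    # padding_numel=2 and partitioned_numel=3 (4 ranks * 3 slots = 12 = 10 + 2)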
+    return partitioned_numel, padding_numel
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
+                                               param_shapes,
+                                               fp32_flat_groups,
+                                               buffers):
+
+    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+    # param, re-consolidating each param, while dealing with padding if any
+
+    avail_numel = fp32_flat_groups[0].numel() * world_size
+    # merge list of dicts, preserving order
+    param_shapes = {k: v for d in param_shapes for k, v in d.items()}
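+    # note: avail_numel counts every slot in the flat partitions, including
+    # alignment padding, so it can exceed the sum of the true parameter sizes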
+
+    if debug:
+        for i in range(world_size):
+            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+    wanted_params = len(param_shapes)
+    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+    # not asserting if there is a mismatch due to possible padding
+    print(f"Have {avail_numel} numels to process.")
+    print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+    state_dict = OrderedDict()
+
+    # buffers
+    state_dict.update(buffers)
+    if debug:
+        print(f"added {len(buffers)} buffers")
+
+    # params
+    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+    # out-of-core computing solution
+    offset = 0
+    total_numel = 0
+    total_params = 0
+    for name, shape in param_shapes.items():
+
+        unpartitioned_numel = shape.numel()
+        total_numel += unpartitioned_numel
+        total_params += 1
+
+        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+        if debug:
+            print(
+                f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+            )
+
+        # XXX: memory usage doubles here
+        state_dict[name] = torch.cat(
+            tuple(fp32_flat_groups[i].narrow(0,
+                                             offset,
+                                             partitioned_numel)
+                  for i in range(world_size)),
+            0).narrow(0,
+                      0,
+                      unpartitioned_numel).view(shape)
+        offset += partitioned_numel
+
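+    # offset was advanced per rank; multiplying by world_size gives the total
+    # number of slots consumed across all ranks, padding included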
+    offset *= world_size
+
+    # Sanity check
+    if offset != avail_numel:
+        raise ValueError(
+            f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+    print(
+        f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
+    )
+
+    return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+    """
+    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+    via a model hub.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+
+    Returns:
+        - pytorch ``state_dict``
+
+    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+    the checkpoint.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+        # do the training and checkpoint saving
+        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+        model = model.cpu() # move to cpu
+        model.load_state_dict(state_dict)
+        # submit to model hub or save the model to share with others
+
+    In this example the ``model`` will no longer be usable in the deepspeed context of the same
+    application; i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+    """
+    if tag is None:
+        latest_path = os.path.join(checkpoint_dir, 'latest')
+        if os.path.isfile(latest_path):
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+        else:
+            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+    if not os.path.isdir(ds_checkpoint_dir):
+        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+    """
+    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+    """
+
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+    print(f"Saving fp32 state dict to {output_file}")
+    torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+    """
+    1. Put the provided model on cpu
+    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+    3. Load it into the provided model
+
+    Args:
+        - ``model``: the model object to update
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+        - ``model``: modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you don't
+    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+    conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
+    of the same application; i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    """
+    logger.info(f"Extracting fp32 weights")
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+    logger.info(f"Overwriting model with fp32 weights")
+    model = model.cpu()
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "checkpoint_dir",
+        type=str,
+        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help=
+        "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
+    )
+    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+    args = parser.parse_args()
+
+    debug = args.debug
+
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
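
A minimal usage sketch of the script added above (paths are illustrative; per the docstrings, DeepSpeed saves the script next to the checkpoint, and the tag is read from the ``latest`` file when not passed explicitly) ::

    # equivalent to running: python zero_to_fp32.py . pytorch_model.bin
    from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    # '.' must contain the 'latest' file and the matching global_step* folder
    convert_zero_checkpoint_to_fp32_state_dict('.', 'pytorch_model.bin')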