pszemraj commited on
Commit
f070b50
·
1 Parent(s): da07eb1

upd chkpt +3 epochs additional finetuning

Browse files
.gitattributes CHANGED
@@ -9,10 +9,14 @@
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
12
  *.onnx filter=lfs diff=lfs merge=lfs -text
13
  *.ot filter=lfs diff=lfs merge=lfs -text
14
  *.parquet filter=lfs diff=lfs merge=lfs -text
15
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
16
  *.pt filter=lfs diff=lfs merge=lfs -text
17
  *.pth filter=lfs diff=lfs merge=lfs -text
18
  *.rar filter=lfs diff=lfs merge=lfs -text
@@ -23,5 +27,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
23
  *.wasm filter=lfs diff=lfs merge=lfs -text
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
- *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
  *.onnx filter=lfs diff=lfs merge=lfs -text
15
  *.ot filter=lfs diff=lfs merge=lfs -text
16
  *.parquet filter=lfs diff=lfs merge=lfs -text
17
  *.pb filter=lfs diff=lfs merge=lfs -text
18
+ *.pickle filter=lfs diff=lfs merge=lfs -text
19
+ *.pkl filter=lfs diff=lfs merge=lfs -text
20
  *.pt filter=lfs diff=lfs merge=lfs -text
21
  *.pth filter=lfs diff=lfs merge=lfs -text
22
  *.rar filter=lfs diff=lfs merge=lfs -text
 
27
  *.wasm filter=lfs diff=lfs merge=lfs -text
28
  *.xz filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "pszemraj/grammar-synthesis-base-V1",
3
  "architectures": [
4
  "T5ForConditionalGeneration"
5
  ],
@@ -30,7 +30,7 @@
30
  "relative_attention_num_buckets": 32,
31
  "tie_word_embeddings": false,
32
  "torch_dtype": "float32",
33
- "transformers_version": "4.20.1",
34
  "use_cache": false,
35
  "vocab_size": 32128
36
  }
 
1
  {
2
+ "_name_or_path": "pszemraj/grammar-synthesis-base-V2",
3
  "architectures": [
4
  "T5ForConditionalGeneration"
5
  ],
 
30
  "relative_attention_num_buckets": 32,
31
  "tie_word_embeddings": false,
32
  "torch_dtype": "float32",
33
+ "transformers_version": "4.21.1",
34
  "use_cache": false,
35
  "vocab_size": 32128
36
  }
latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step350
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bda17b0c9cbe3d287a134c92b7f8390bd8015c0dfb7f0d6906fa34a49441f9e8
3
  size 990347691
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ce6574f3eca389526e65d8d577deb76467a77c374a05ef9933687ca4cb4456c
3
  size 990347691
rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c29f7470f494caf666906f135f44706bfe971f407f95a5c818f60f7b0482f2fa
3
+ size 14503
tokenizer_config.json CHANGED
@@ -104,7 +104,7 @@
104
  "eos_token": "</s>",
105
  "extra_ids": 100,
106
  "model_max_length": 512,
107
- "name_or_path": "pszemraj/grammar-synthesis-base-V1",
108
  "pad_token": "<pad>",
109
  "sp_model_kwargs": {},
110
  "special_tokens_map_file": "/home/patrick/.cache/huggingface/transformers/76bf19bfedb85afbe644966ca9ab7b0404d753a41bf601115bced39f825ffa9c.c94798918c92ded6aeef2d2f0e666d2cc4145eca1aa6e1336fde07f2e13e2f46",
 
104
  "eos_token": "</s>",
105
  "extra_ids": 100,
106
  "model_max_length": 512,
107
+ "name_or_path": "pszemraj/grammar-synthesis-base-V2",
108
  "pad_token": "<pad>",
109
  "sp_model_kwargs": {},
110
  "special_tokens_map_file": "/home/patrick/.cache/huggingface/transformers/76bf19bfedb85afbe644966ca9ab7b0404d753a41bf601115bced39f825ffa9c.c94798918c92ded6aeef2d2f0e666d2cc4145eca1aa6e1336fde07f2e13e2f46",
trainer_state.json CHANGED
@@ -1,1075 +1,1066 @@
1
  {
2
  "best_metric": null,
3
  "best_model_checkpoint": null,
4
- "epoch": 1.0,
5
- "global_step": 351,
6
  "is_hyper_param_search": false,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [
10
  {
11
  "epoch": 0.01,
12
- "learning_rate": 5e-05,
13
- "loss": 0.1291,
14
  "step": 2
15
  },
16
  {
17
  "epoch": 0.01,
18
- "learning_rate": 0.0001,
19
- "loss": 0.1354,
20
  "step": 4
21
  },
22
  {
23
  "epoch": 0.02,
24
- "learning_rate": 0.00015000000000000001,
25
- "loss": 0.1399,
26
  "step": 6
27
  },
28
  {
29
  "epoch": 0.02,
30
- "learning_rate": 0.0002,
31
- "loss": 0.143,
32
  "step": 8
33
  },
34
  {
35
  "epoch": 0.03,
36
- "learning_rate": 0.0001999832224185224,
37
- "loss": 0.1455,
38
  "step": 10
39
  },
40
  {
41
  "epoch": 0.03,
42
- "learning_rate": 0.00019993289530383432,
43
- "loss": 0.1346,
44
  "step": 12
45
  },
46
  {
47
  "epoch": 0.04,
48
- "learning_rate": 0.00019984903554328114,
49
- "loss": 0.1458,
50
  "step": 14
51
  },
52
  {
53
  "epoch": 0.05,
54
- "learning_rate": 0.00019973167127614215,
55
- "loss": 0.1354,
56
  "step": 16
57
  },
58
  {
59
  "epoch": 0.05,
60
- "learning_rate": 0.0001995808418841885,
61
- "loss": 0.1248,
62
  "step": 18
63
  },
64
  {
65
  "epoch": 0.06,
66
- "learning_rate": 0.0001993965979784684,
67
- "loss": 0.1322,
68
  "step": 20
69
  },
70
  {
71
  "epoch": 0.06,
72
- "learning_rate": 0.0001991790013823246,
73
- "loss": 0.1371,
74
  "step": 22
75
  },
76
  {
77
  "epoch": 0.07,
78
- "learning_rate": 0.0001989281251106496,
79
- "loss": 0.1195,
80
  "step": 24
81
  },
82
  {
83
  "epoch": 0.07,
84
- "learning_rate": 0.00019864405334538517,
85
- "loss": 0.1193,
86
  "step": 26
87
  },
88
  {
89
  "epoch": 0.08,
90
- "learning_rate": 0.000198326881407275,
91
- "loss": 0.1313,
92
  "step": 28
93
  },
94
  {
95
  "epoch": 0.09,
96
- "learning_rate": 0.00019797671572387984,
97
- "loss": 0.148,
98
  "step": 30
99
  },
100
  {
101
  "epoch": 0.09,
102
- "learning_rate": 0.0001975936737938653,
103
- "loss": 0.1384,
104
  "step": 32
105
  },
106
  {
107
  "epoch": 0.1,
108
- "learning_rate": 0.00019717788414757523,
109
- "loss": 0.1317,
110
  "step": 34
111
  },
112
  {
113
  "epoch": 0.1,
114
- "learning_rate": 0.00019672948630390294,
115
- "loss": 0.1364,
116
  "step": 36
117
  },
118
  {
119
  "epoch": 0.11,
120
- "learning_rate": 0.00019624863072347564,
121
- "loss": 0.127,
122
  "step": 38
123
  },
124
  {
125
  "epoch": 0.11,
126
- "learning_rate": 0.00019573547875816684,
127
- "loss": 0.1281,
128
  "step": 40
129
  },
130
  {
131
  "epoch": 0.12,
132
- "learning_rate": 0.0001951902025969548,
133
- "loss": 0.1303,
134
  "step": 42
135
  },
136
  {
137
  "epoch": 0.13,
138
- "learning_rate": 0.0001946129852081439,
139
- "loss": 0.1364,
140
  "step": 44
141
  },
142
  {
143
  "epoch": 0.13,
144
- "learning_rate": 0.00019400402027796955,
145
- "loss": 0.1262,
146
  "step": 46
147
  },
148
  {
149
  "epoch": 0.14,
150
- "learning_rate": 0.00019336351214560647,
151
- "loss": 0.1244,
152
  "step": 48
153
  },
154
  {
155
  "epoch": 0.14,
156
- "learning_rate": 0.0001926916757346022,
157
- "loss": 0.1261,
158
  "step": 50
159
  },
160
  {
161
  "epoch": 0.15,
162
- "learning_rate": 0.0001919887364807592,
163
- "loss": 0.1276,
164
  "step": 52
165
  },
166
  {
167
  "epoch": 0.15,
168
- "learning_rate": 0.00019125493025648962,
169
- "loss": 0.1181,
170
  "step": 54
171
  },
172
  {
173
  "epoch": 0.16,
174
- "learning_rate": 0.00019049050329166778,
175
- "loss": 0.1218,
176
  "step": 56
177
  },
178
  {
179
  "epoch": 0.17,
180
- "learning_rate": 0.0001896957120910074,
181
- "loss": 0.1255,
182
  "step": 58
183
  },
184
  {
185
  "epoch": 0.17,
186
- "learning_rate": 0.00018887082334799097,
187
- "loss": 0.1279,
188
  "step": 60
189
  },
190
  {
191
  "epoch": 0.18,
192
- "learning_rate": 0.00018801611385538047,
193
- "loss": 0.1227,
194
  "step": 62
195
  },
196
  {
197
  "epoch": 0.18,
198
- "learning_rate": 0.00018713187041233896,
199
- "loss": 0.132,
200
  "step": 64
201
  },
202
  {
203
  "epoch": 0.19,
204
- "learning_rate": 0.00018621838972819458,
205
- "loss": 0.1249,
206
  "step": 66
207
  },
208
  {
209
  "epoch": 0.19,
210
- "learning_rate": 0.00018527597832287954,
211
- "loss": 0.1182,
212
  "step": 68
213
  },
214
  {
215
  "epoch": 0.2,
216
- "learning_rate": 0.00018430495242407658,
217
- "loss": 0.1325,
218
  "step": 70
219
  },
220
  {
221
  "epoch": 0.21,
222
- "learning_rate": 0.00018330563786110834,
223
- "loss": 0.1209,
224
  "step": 72
225
  },
226
  {
227
  "epoch": 0.21,
228
- "learning_rate": 0.0001822783699556049,
229
- "loss": 0.1322,
230
  "step": 74
231
  },
232
  {
233
  "epoch": 0.22,
234
- "learning_rate": 0.00018122349340898595,
235
- "loss": 0.1225,
236
  "step": 76
237
  },
238
  {
239
  "epoch": 0.22,
240
- "learning_rate": 0.00018014136218679567,
241
- "loss": 0.1209,
242
  "step": 78
243
  },
244
  {
245
  "epoch": 0.23,
246
- "learning_rate": 0.00017903233939992906,
247
- "loss": 0.125,
248
  "step": 80
249
  },
250
  {
251
  "epoch": 0.23,
252
- "learning_rate": 0.00017789679718278943,
253
- "loss": 0.1232,
254
  "step": 82
255
  },
256
  {
257
  "epoch": 0.24,
258
- "learning_rate": 0.00017673511656841822,
259
- "loss": 0.1199,
260
  "step": 84
261
  },
262
  {
263
  "epoch": 0.25,
264
- "learning_rate": 0.00017554768736063859,
265
- "loss": 0.1228,
266
  "step": 86
267
  },
268
  {
269
  "epoch": 0.25,
270
- "learning_rate": 0.00017433490800325614,
271
- "loss": 0.1303,
272
  "step": 88
273
  },
274
  {
275
  "epoch": 0.26,
276
- "learning_rate": 0.00017309718544636057,
277
- "loss": 0.12,
278
  "step": 90
279
  },
280
  {
281
  "epoch": 0.26,
282
- "learning_rate": 0.00017183493500977278,
283
- "loss": 0.1164,
284
  "step": 92
285
  },
286
  {
287
  "epoch": 0.27,
288
- "learning_rate": 0.00017054858024368366,
289
- "loss": 0.123,
290
  "step": 94
291
  },
292
  {
293
  "epoch": 0.27,
294
- "learning_rate": 0.00016923855278653114,
295
- "loss": 0.1156,
296
  "step": 96
297
  },
298
  {
299
  "epoch": 0.28,
300
- "learning_rate": 0.00016790529222016328,
301
- "loss": 0.1199,
302
  "step": 98
303
  },
304
  {
305
  "epoch": 0.28,
306
- "learning_rate": 0.00016654924592233568,
307
- "loss": 0.115,
308
  "step": 100
309
  },
310
  {
311
  "epoch": 0.29,
312
- "learning_rate": 0.00016517086891659335,
313
- "loss": 0.1248,
314
  "step": 102
315
  },
316
  {
317
  "epoch": 0.3,
318
- "learning_rate": 0.00016377062371958668,
319
- "loss": 0.1229,
320
  "step": 104
321
  },
322
  {
323
  "epoch": 0.3,
324
- "learning_rate": 0.00016234898018587337,
325
- "loss": 0.1188,
326
  "step": 106
327
  },
328
  {
329
  "epoch": 0.31,
330
- "learning_rate": 0.00016090641535025774,
331
- "loss": 0.1209,
332
  "step": 108
333
  },
334
  {
335
  "epoch": 0.31,
336
- "learning_rate": 0.00015944341326772112,
337
- "loss": 0.1197,
338
  "step": 110
339
  },
340
  {
341
  "epoch": 0.32,
342
- "learning_rate": 0.00015796046485099633,
343
- "loss": 0.1192,
344
  "step": 112
345
  },
346
  {
347
  "epoch": 0.32,
348
- "learning_rate": 0.0001564580677058412,
349
- "loss": 0.1124,
350
  "step": 114
351
  },
352
  {
353
  "epoch": 0.33,
354
- "learning_rate": 0.00015493672596406598,
355
- "loss": 0.1189,
356
  "step": 116
357
  },
358
  {
359
  "epoch": 0.34,
360
- "learning_rate": 0.00015339695011437127,
361
- "loss": 0.1123,
362
  "step": 118
363
  },
364
  {
365
  "epoch": 0.34,
366
- "learning_rate": 0.00015183925683105254,
367
- "loss": 0.1139,
368
  "step": 120
369
  },
370
  {
371
  "epoch": 0.35,
372
- "learning_rate": 0.00015026416880062931,
373
- "loss": 0.1165,
374
  "step": 122
375
  },
376
  {
377
  "epoch": 0.35,
378
- "learning_rate": 0.00014867221454645696,
379
- "loss": 0.1265,
380
  "step": 124
381
  },
382
  {
383
  "epoch": 0.36,
384
- "learning_rate": 0.00014706392825137964,
385
- "loss": 0.1179,
386
  "step": 126
387
  },
388
  {
389
  "epoch": 0.36,
390
- "learning_rate": 0.0001454398495784844,
391
- "loss": 0.1279,
392
  "step": 128
393
  },
394
  {
395
  "epoch": 0.37,
396
- "learning_rate": 0.00014380052349001647,
397
- "loss": 0.106,
398
  "step": 130
399
  },
400
  {
401
  "epoch": 0.38,
402
- "learning_rate": 0.00014214650006451622,
403
- "loss": 0.118,
404
  "step": 132
405
  },
406
  {
407
  "epoch": 0.38,
408
- "learning_rate": 0.00014047833431223938,
409
- "loss": 0.1276,
410
  "step": 134
411
  },
412
  {
413
  "epoch": 0.39,
414
- "learning_rate": 0.00013879658598892254,
415
- "loss": 0.1255,
416
  "step": 136
417
  },
418
  {
419
  "epoch": 0.39,
420
- "learning_rate": 0.000137101819407956,
421
- "loss": 0.1191,
422
  "step": 138
423
  },
424
  {
425
  "epoch": 0.4,
426
- "learning_rate": 0.00013539460325102777,
427
- "loss": 0.111,
428
  "step": 140
429
  },
430
  {
431
  "epoch": 0.4,
432
- "learning_rate": 0.00013367551037730128,
433
- "loss": 0.1251,
434
  "step": 142
435
  },
436
  {
437
  "epoch": 0.41,
438
- "learning_rate": 0.00013194511763119172,
439
- "loss": 0.1193,
440
  "step": 144
441
  },
442
  {
443
  "epoch": 0.42,
444
- "learning_rate": 0.0001302040056488047,
445
- "loss": 0.11,
446
  "step": 146
447
  },
448
  {
449
  "epoch": 0.42,
450
- "learning_rate": 0.00012845275866310324,
451
- "loss": 0.1033,
452
  "step": 148
453
  },
454
  {
455
  "epoch": 0.43,
456
- "learning_rate": 0.00012669196430786713,
457
- "loss": 0.1149,
458
  "step": 150
459
  },
460
  {
461
  "epoch": 0.43,
462
- "learning_rate": 0.00012492221342051154,
463
- "loss": 0.1231,
464
  "step": 152
465
  },
466
  {
467
  "epoch": 0.44,
468
- "learning_rate": 0.00012314409984383066,
469
- "loss": 0.1244,
470
  "step": 154
471
  },
472
  {
473
  "epoch": 0.44,
474
- "learning_rate": 0.00012135822022673263,
475
- "loss": 0.1125,
476
  "step": 156
477
  },
478
  {
479
  "epoch": 0.45,
480
- "learning_rate": 0.00011956517382403321,
481
- "loss": 0.1102,
482
  "step": 158
483
  },
484
  {
485
  "epoch": 0.46,
486
- "learning_rate": 0.0001177655622953746,
487
- "loss": 0.1175,
488
  "step": 160
489
  },
490
  {
491
  "epoch": 0.46,
492
- "learning_rate": 0.00011595998950333793,
493
- "loss": 0.116,
494
  "step": 162
495
  },
496
  {
497
  "epoch": 0.47,
498
- "learning_rate": 0.00011414906131081575,
499
- "loss": 0.1144,
500
  "step": 164
501
  },
502
  {
503
  "epoch": 0.47,
504
- "learning_rate": 0.00011233338537771407,
505
- "loss": 0.1229,
506
  "step": 166
507
  },
508
  {
509
  "epoch": 0.48,
510
- "learning_rate": 0.00011051357095705101,
511
- "loss": 0.1117,
512
  "step": 168
513
  },
514
  {
515
  "epoch": 0.48,
516
- "learning_rate": 0.0001086902286905209,
517
- "loss": 0.1144,
518
  "step": 170
519
  },
520
  {
521
  "epoch": 0.49,
522
- "learning_rate": 0.00010686397040359253,
523
- "loss": 0.109,
524
  "step": 172
525
  },
526
  {
527
  "epoch": 0.5,
528
- "learning_rate": 0.00010503540890020997,
529
- "loss": 0.1151,
530
  "step": 174
531
  },
532
  {
533
  "epoch": 0.5,
534
- "learning_rate": 0.00010320515775716555,
535
- "loss": 0.1197,
536
  "step": 176
537
  },
538
  {
539
  "epoch": 0.51,
540
- "learning_rate": 0.00010137383111821266,
541
- "loss": 0.1073,
542
  "step": 178
543
  },
544
  {
545
  "epoch": 0.51,
546
- "learning_rate": 9.954204348798938e-05,
547
- "loss": 0.1114,
548
  "step": 180
549
  },
550
  {
551
  "epoch": 0.52,
552
- "learning_rate": 9.771040952581998e-05,
553
- "loss": 0.1158,
554
  "step": 182
555
  },
556
  {
557
  "epoch": 0.52,
558
- "learning_rate": 9.587954383946517e-05,
559
- "loss": 0.1103,
560
  "step": 184
561
  },
562
  {
563
  "epoch": 0.53,
564
- "learning_rate": 9.405006077888954e-05,
565
- "loss": 0.1083,
566
  "step": 186
567
  },
568
  {
569
  "epoch": 0.54,
570
- "learning_rate": 9.22225742301153e-05,
571
- "loss": 0.1137,
572
  "step": 188
573
  },
574
  {
575
  "epoch": 0.54,
576
- "learning_rate": 9.039769740923183e-05,
577
- "loss": 0.1123,
578
  "step": 190
579
  },
580
  {
581
  "epoch": 0.55,
582
- "learning_rate": 8.857604265663017e-05,
583
- "loss": 0.1191,
584
  "step": 192
585
  },
586
  {
587
  "epoch": 0.55,
588
- "learning_rate": 8.675822123153103e-05,
589
- "loss": 0.1119,
590
  "step": 194
591
  },
592
  {
593
  "epoch": 0.56,
594
- "learning_rate": 8.494484310687581e-05,
595
- "loss": 0.1196,
596
  "step": 196
597
  },
598
  {
599
  "epoch": 0.56,
600
- "learning_rate": 8.313651676464923e-05,
601
- "loss": 0.1135,
602
  "step": 198
603
  },
604
  {
605
  "epoch": 0.57,
606
- "learning_rate": 8.133384899170225e-05,
607
- "loss": 0.1092,
608
  "step": 200
609
  },
610
  {
611
  "epoch": 0.58,
612
- "learning_rate": 7.953744467614354e-05,
613
- "loss": 0.112,
614
  "step": 202
615
  },
616
  {
617
  "epoch": 0.58,
618
- "learning_rate": 7.774790660436858e-05,
619
- "loss": 0.1099,
620
  "step": 204
621
  },
622
  {
623
  "epoch": 0.59,
624
- "learning_rate": 7.596583525879344e-05,
625
- "loss": 0.1064,
626
  "step": 206
627
  },
628
  {
629
  "epoch": 0.59,
630
- "learning_rate": 7.419182861636218e-05,
631
- "loss": 0.1165,
632
  "step": 208
633
  },
634
  {
635
  "epoch": 0.6,
636
- "learning_rate": 7.242648194789446e-05,
637
- "loss": 0.1215,
638
  "step": 210
639
  },
640
  {
641
  "epoch": 0.6,
642
- "learning_rate": 7.067038761834164e-05,
643
- "loss": 0.1025,
644
  "step": 212
645
  },
646
  {
647
  "epoch": 0.61,
648
- "learning_rate": 6.89241348880176e-05,
649
- "loss": 0.1109,
650
  "step": 214
651
  },
652
  {
653
  "epoch": 0.62,
654
- "learning_rate": 6.718830971487165e-05,
655
- "loss": 0.1157,
656
  "step": 216
657
  },
658
  {
659
  "epoch": 0.62,
660
- "learning_rate": 6.546349455786926e-05,
661
- "loss": 0.1108,
662
  "step": 218
663
  },
664
  {
665
  "epoch": 0.63,
666
- "learning_rate": 6.3750268181547e-05,
667
- "loss": 0.1043,
668
  "step": 220
669
  },
670
  {
671
  "epoch": 0.63,
672
- "learning_rate": 6.204920546180728e-05,
673
- "loss": 0.108,
674
  "step": 222
675
  },
676
  {
677
  "epoch": 0.64,
678
- "learning_rate": 6.036087719301763e-05,
679
- "loss": 0.1125,
680
  "step": 224
681
  },
682
  {
683
  "epoch": 0.64,
684
- "learning_rate": 5.868584989647994e-05,
685
- "loss": 0.1096,
686
  "step": 226
687
  },
688
  {
689
  "epoch": 0.65,
690
- "learning_rate": 5.702468563033306e-05,
691
- "loss": 0.1114,
692
  "step": 228
693
  },
694
  {
695
  "epoch": 0.66,
696
- "learning_rate": 5.5377941800953416e-05,
697
- "loss": 0.0981,
698
  "step": 230
699
  },
700
  {
701
  "epoch": 0.66,
702
- "learning_rate": 5.37461709759165e-05,
703
- "loss": 0.1147,
704
  "step": 232
705
  },
706
  {
707
  "epoch": 0.67,
708
- "learning_rate": 5.2129920698581606e-05,
709
- "loss": 0.1085,
710
  "step": 234
711
  },
712
  {
713
  "epoch": 0.67,
714
- "learning_rate": 5.0529733304363145e-05,
715
- "loss": 0.1133,
716
  "step": 236
717
  },
718
  {
719
  "epoch": 0.68,
720
- "learning_rate": 4.894614573874877e-05,
721
- "loss": 0.1036,
722
  "step": 238
723
  },
724
  {
725
  "epoch": 0.68,
726
- "learning_rate": 4.7379689377126735e-05,
727
- "loss": 0.1106,
728
  "step": 240
729
  },
730
  {
731
  "epoch": 0.69,
732
- "learning_rate": 4.583088984648172e-05,
733
- "loss": 0.1185,
734
  "step": 242
735
  },
736
  {
737
  "epoch": 0.7,
738
- "learning_rate": 4.430026684902017e-05,
739
- "loss": 0.1127,
740
  "step": 244
741
  },
742
  {
743
  "epoch": 0.7,
744
- "learning_rate": 4.278833398778306e-05,
745
- "loss": 0.1187,
746
  "step": 246
747
  },
748
  {
749
  "epoch": 0.71,
750
- "learning_rate": 4.129559859430573e-05,
751
- "loss": 0.12,
752
  "step": 248
753
  },
754
  {
755
  "epoch": 0.71,
756
- "learning_rate": 3.982256155838199e-05,
757
- "loss": 0.1159,
758
  "step": 250
759
  },
760
  {
761
  "epoch": 0.72,
762
- "learning_rate": 3.836971715998967e-05,
763
- "loss": 0.1278,
764
  "step": 252
765
  },
766
  {
767
  "epoch": 0.72,
768
- "learning_rate": 3.693755290343409e-05,
769
- "loss": 0.1188,
770
  "step": 254
771
  },
772
  {
773
  "epoch": 0.73,
774
- "learning_rate": 3.5526549353765296e-05,
775
- "loss": 0.1091,
776
  "step": 256
777
  },
778
  {
779
  "epoch": 0.74,
780
- "learning_rate": 3.413717997552376e-05,
781
- "loss": 0.1014,
782
  "step": 258
783
  },
784
  {
785
  "epoch": 0.74,
786
- "learning_rate": 3.276991097386831e-05,
787
- "loss": 0.1147,
788
  "step": 260
789
  },
790
  {
791
  "epoch": 0.75,
792
- "learning_rate": 3.142520113814059e-05,
793
- "loss": 0.1144,
794
  "step": 262
795
  },
796
  {
797
  "epoch": 0.75,
798
- "learning_rate": 3.010350168791719e-05,
799
- "loss": 0.1161,
800
  "step": 264
801
  },
802
  {
803
  "epoch": 0.76,
804
- "learning_rate": 2.8805256121602398e-05,
805
- "loss": 0.1145,
806
  "step": 266
807
  },
808
  {
809
  "epoch": 0.76,
810
- "learning_rate": 2.7530900067611577e-05,
811
- "loss": 0.1151,
812
  "step": 268
813
  },
814
  {
815
  "epoch": 0.77,
816
- "learning_rate": 2.62808611381953e-05,
817
- "loss": 0.1149,
818
  "step": 270
819
  },
820
  {
821
  "epoch": 0.77,
822
- "learning_rate": 2.5055558785953303e-05,
823
- "loss": 0.114,
824
  "step": 272
825
  },
826
  {
827
  "epoch": 0.78,
828
- "learning_rate": 2.3855404163086558e-05,
829
- "loss": 0.1187,
830
  "step": 274
831
  },
832
  {
833
  "epoch": 0.79,
834
- "learning_rate": 2.268079998343453e-05,
835
- "loss": 0.1108,
836
  "step": 276
837
  },
838
  {
839
  "epoch": 0.79,
840
- "learning_rate": 2.1532140387343735e-05,
841
- "loss": 0.1112,
842
  "step": 278
843
  },
844
  {
845
  "epoch": 0.8,
846
- "learning_rate": 2.0409810809413486e-05,
847
- "loss": 0.1132,
848
  "step": 280
849
  },
850
  {
851
  "epoch": 0.8,
852
- "learning_rate": 1.9314187849162524e-05,
853
- "loss": 0.1098,
854
  "step": 282
855
  },
856
  {
857
  "epoch": 0.81,
858
- "learning_rate": 1.8245639144660532e-05,
859
- "loss": 0.1073,
860
  "step": 284
861
  },
862
  {
863
  "epoch": 0.81,
864
- "learning_rate": 1.720452324916656e-05,
865
- "loss": 0.1079,
866
  "step": 286
867
  },
868
  {
869
  "epoch": 0.82,
870
- "learning_rate": 1.619118951081594e-05,
871
- "loss": 0.1068,
872
  "step": 288
873
  },
874
  {
875
  "epoch": 0.83,
876
- "learning_rate": 1.5205977955395812e-05,
877
- "loss": 0.1162,
878
  "step": 290
879
  },
880
  {
881
  "epoch": 0.83,
882
- "learning_rate": 1.424921917224905e-05,
883
- "loss": 0.111,
884
  "step": 292
885
  },
886
  {
887
  "epoch": 0.84,
888
- "learning_rate": 1.3321234203344435e-05,
889
- "loss": 0.1145,
890
  "step": 294
891
  },
892
  {
893
  "epoch": 0.84,
894
- "learning_rate": 1.2422334435550509e-05,
895
- "loss": 0.1137,
896
  "step": 296
897
  },
898
  {
899
  "epoch": 0.85,
900
- "learning_rate": 1.1552821496149135e-05,
901
- "loss": 0.1009,
902
  "step": 298
903
  },
904
  {
905
  "epoch": 0.85,
906
- "learning_rate": 1.0712987151624054e-05,
907
- "loss": 0.1055,
908
  "step": 300
909
  },
910
  {
911
  "epoch": 0.86,
912
- "learning_rate": 9.903113209758096e-06,
913
- "loss": 0.1054,
914
  "step": 302
915
  },
916
  {
917
  "epoch": 0.87,
918
- "learning_rate": 9.123471425072206e-06,
919
- "loss": 0.1042,
920
  "step": 304
921
  },
922
  {
923
  "epoch": 0.87,
924
- "learning_rate": 8.374323407637742e-06,
925
- "loss": 0.1032,
926
  "step": 306
927
  },
928
  {
929
  "epoch": 0.88,
930
- "learning_rate": 7.655920535292682e-06,
931
- "loss": 0.104,
932
  "step": 308
933
  },
934
  {
935
  "epoch": 0.88,
936
- "learning_rate": 6.968503869291521e-06,
937
- "loss": 0.1064,
938
  "step": 310
939
  },
940
  {
941
  "epoch": 0.89,
942
- "learning_rate": 6.312304073416719e-06,
943
- "loss": 0.1081,
944
  "step": 312
945
  },
946
  {
947
  "epoch": 0.89,
948
- "learning_rate": 5.687541336579128e-06,
949
- "loss": 0.1151,
950
  "step": 314
951
  },
952
  {
953
  "epoch": 0.9,
954
- "learning_rate": 5.094425298933136e-06,
955
- "loss": 0.1184,
956
  "step": 316
957
  },
958
  {
959
  "epoch": 0.91,
960
- "learning_rate": 4.5331549815317174e-06,
961
- "loss": 0.1061,
962
  "step": 318
963
  },
964
  {
965
  "epoch": 0.91,
966
- "learning_rate": 4.003918719544464e-06,
967
- "loss": 0.1104,
968
  "step": 320
969
  },
970
  {
971
  "epoch": 0.92,
972
- "learning_rate": 3.5068940990615197e-06,
973
- "loss": 0.1022,
974
  "step": 322
975
  },
976
  {
977
  "epoch": 0.92,
978
- "learning_rate": 3.0422478975042246e-06,
979
- "loss": 0.1042,
980
  "step": 324
981
  },
982
  {
983
  "epoch": 0.93,
984
- "learning_rate": 2.6101360276626798e-06,
985
- "loss": 0.1126,
986
  "step": 326
987
  },
988
  {
989
  "epoch": 0.93,
990
- "learning_rate": 2.2107034853789288e-06,
991
- "loss": 0.1084,
992
  "step": 328
993
  },
994
  {
995
  "epoch": 0.94,
996
- "learning_rate": 1.8440843008934561e-06,
997
- "loss": 0.1047,
998
  "step": 330
999
  },
1000
  {
1001
  "epoch": 0.95,
1002
- "learning_rate": 1.5104014938710497e-06,
1003
- "loss": 0.0979,
1004
  "step": 332
1005
  },
1006
  {
1007
  "epoch": 0.95,
1008
- "learning_rate": 1.209767032121345e-06,
1009
- "loss": 0.1147,
1010
  "step": 334
1011
  },
1012
  {
1013
  "epoch": 0.96,
1014
- "learning_rate": 9.422817940278772e-07,
1015
- "loss": 0.1116,
1016
  "step": 336
1017
  },
1018
  {
1019
  "epoch": 0.96,
1020
- "learning_rate": 7.080355346981815e-07,
1021
- "loss": 0.1036,
1022
  "step": 338
1023
  },
1024
  {
1025
  "epoch": 0.97,
1026
- "learning_rate": 5.071068558462732e-07,
1027
- "loss": 0.115,
1028
  "step": 340
1029
  },
1030
  {
1031
  "epoch": 0.97,
1032
- "learning_rate": 3.3956317941779005e-07,
1033
- "loss": 0.108,
1034
  "step": 342
1035
  },
1036
  {
1037
  "epoch": 0.98,
1038
- "learning_rate": 2.054607249663665e-07,
1039
- "loss": 0.1063,
1040
  "step": 344
1041
  },
1042
  {
1043
  "epoch": 0.99,
1044
- "learning_rate": 1.0484449078912439e-07,
1045
- "loss": 0.1028,
1046
  "step": 346
1047
  },
1048
  {
1049
  "epoch": 0.99,
1050
- "learning_rate": 3.774823882738421e-08,
1051
- "loss": 0.1009,
1052
  "step": 348
1053
  },
1054
  {
1055
  "epoch": 1.0,
1056
- "learning_rate": 4.194483337860433e-09,
1057
- "loss": 0.1114,
1058
  "step": 350
1059
- },
1060
- {
1061
- "epoch": 1.0,
1062
- "step": 351,
1063
- "total_flos": 2.4611000956852634e+17,
1064
- "train_loss": 0.11711755568994756,
1065
- "train_runtime": 56238.1049,
1066
- "train_samples_per_second": 3.195,
1067
- "train_steps_per_second": 0.006
1068
  }
1069
  ],
1070
  "max_steps": 351,
1071
  "num_train_epochs": 1,
1072
- "total_flos": 2.4611000956852634e+17,
1073
  "trial_name": null,
1074
  "trial_params": null
1075
  }
 
1
  {
2
  "best_metric": null,
3
  "best_model_checkpoint": null,
4
+ "epoch": 0.9971509971509972,
5
+ "global_step": 350,
6
  "is_hyper_param_search": false,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [
10
  {
11
  "epoch": 0.01,
12
+ "learning_rate": 2e-05,
13
+ "loss": 0.097,
14
  "step": 2
15
  },
16
  {
17
  "epoch": 0.01,
18
+ "learning_rate": 4e-05,
19
+ "loss": 0.1096,
20
  "step": 4
21
  },
22
  {
23
  "epoch": 0.02,
24
+ "learning_rate": 3.999672139632675e-05,
25
+ "loss": 0.098,
26
  "step": 6
27
  },
28
  {
29
  "epoch": 0.02,
30
+ "learning_rate": 3.9986886660231184e-05,
31
+ "loss": 0.0934,
32
  "step": 8
33
  },
34
  {
35
  "epoch": 0.03,
36
+ "learning_rate": 3.997049901613351e-05,
37
+ "loss": 0.1051,
38
  "step": 10
39
  },
40
  {
41
  "epoch": 0.03,
42
+ "learning_rate": 3.9947563836892725e-05,
43
+ "loss": 0.1093,
44
  "step": 12
45
  },
46
  {
47
  "epoch": 0.04,
48
+ "learning_rate": 3.9918088642045126e-05,
49
+ "loss": 0.1113,
50
  "step": 14
51
  },
52
  {
53
  "epoch": 0.05,
54
+ "learning_rate": 3.9882083095338934e-05,
55
+ "loss": 0.1097,
56
  "step": 16
57
  },
58
  {
59
  "epoch": 0.05,
60
+ "learning_rate": 3.98395590015659e-05,
61
+ "loss": 0.0988,
62
  "step": 18
63
  },
64
  {
65
  "epoch": 0.06,
66
+ "learning_rate": 3.979053030269103e-05,
67
+ "loss": 0.1046,
68
  "step": 20
69
  },
70
  {
71
  "epoch": 0.06,
72
+ "learning_rate": 3.9735013073281564e-05,
73
+ "loss": 0.102,
74
  "step": 22
75
  },
76
  {
77
  "epoch": 0.07,
78
+ "learning_rate": 3.967302551523671e-05,
79
+ "loss": 0.1029,
80
  "step": 24
81
  },
82
  {
83
  "epoch": 0.07,
84
+ "learning_rate": 3.960458795182003e-05,
85
+ "loss": 0.1102,
86
  "step": 26
87
  },
88
  {
89
  "epoch": 0.08,
90
+ "learning_rate": 3.95297228209962e-05,
91
+ "loss": 0.1021,
92
  "step": 28
93
  },
94
  {
95
  "epoch": 0.09,
96
+ "learning_rate": 3.944845466807451e-05,
97
+ "loss": 0.1015,
98
  "step": 30
99
  },
100
  {
101
  "epoch": 0.09,
102
+ "learning_rate": 3.936081013766143e-05,
103
+ "loss": 0.109,
104
  "step": 32
105
  },
106
  {
107
  "epoch": 0.1,
108
+ "learning_rate": 3.9266817964924905e-05,
109
+ "loss": 0.1131,
110
  "step": 34
111
  },
112
  {
113
  "epoch": 0.1,
114
+ "learning_rate": 3.91665089661732e-05,
115
+ "loss": 0.1019,
116
  "step": 36
117
  },
118
  {
119
  "epoch": 0.11,
120
+ "learning_rate": 3.9059916028751496e-05,
121
+ "loss": 0.1032,
122
  "step": 38
123
  },
124
  {
125
  "epoch": 0.11,
126
+ "learning_rate": 3.894707410025941e-05,
127
+ "loss": 0.104,
128
  "step": 40
129
  },
130
  {
131
  "epoch": 0.12,
132
+ "learning_rate": 3.882802017709307e-05,
133
+ "loss": 0.1015,
134
  "step": 42
135
  },
136
  {
137
  "epoch": 0.13,
138
+ "learning_rate": 3.870279329231546e-05,
139
+ "loss": 0.106,
140
  "step": 44
141
  },
142
  {
143
  "epoch": 0.13,
144
+ "learning_rate": 3.857143450285901e-05,
145
+ "loss": 0.0968,
146
  "step": 46
147
  },
148
  {
149
  "epoch": 0.14,
150
+ "learning_rate": 3.84339868760647e-05,
151
+ "loss": 0.0982,
152
  "step": 48
153
  },
154
  {
155
  "epoch": 0.14,
156
+ "learning_rate": 3.829049547556193e-05,
157
+ "loss": 0.1069,
158
  "step": 50
159
  },
160
  {
161
  "epoch": 0.15,
162
+ "learning_rate": 3.8141007346493964e-05,
163
+ "loss": 0.1048,
164
  "step": 52
165
  },
166
  {
167
  "epoch": 0.15,
168
+ "learning_rate": 3.798557150009373e-05,
169
+ "loss": 0.1054,
170
  "step": 54
171
  },
172
  {
173
  "epoch": 0.16,
174
+ "learning_rate": 3.782423889761492e-05,
175
+ "loss": 0.1146,
176
  "step": 56
177
  },
178
  {
179
  "epoch": 0.17,
180
+ "learning_rate": 3.7657062433623825e-05,
181
+ "loss": 0.1015,
182
  "step": 58
183
  },
184
  {
185
  "epoch": 0.17,
186
+ "learning_rate": 3.748409691865737e-05,
187
+ "loss": 0.0976,
188
  "step": 60
189
  },
190
  {
191
  "epoch": 0.18,
192
+ "learning_rate": 3.7305399061252795e-05,
193
+ "loss": 0.1108,
194
  "step": 62
195
  },
196
  {
197
  "epoch": 0.18,
198
+ "learning_rate": 3.712102744935529e-05,
199
+ "loss": 0.1041,
200
  "step": 64
201
  },
202
  {
203
  "epoch": 0.19,
204
+ "learning_rate": 3.6931042531109246e-05,
205
+ "loss": 0.1061,
206
  "step": 66
207
  },
208
  {
209
  "epoch": 0.19,
210
+ "learning_rate": 3.673550659503975e-05,
211
+ "loss": 0.0952,
212
  "step": 68
213
  },
214
  {
215
  "epoch": 0.2,
216
+ "learning_rate": 3.6534483749630624e-05,
217
+ "loss": 0.1023,
218
  "step": 70
219
  },
220
  {
221
  "epoch": 0.21,
222
+ "learning_rate": 3.6328039902305806e-05,
223
+ "loss": 0.0961,
224
  "step": 72
225
  },
226
  {
227
  "epoch": 0.21,
228
+ "learning_rate": 3.611624273782092e-05,
229
+ "loss": 0.0971,
230
  "step": 74
231
  },
232
  {
233
  "epoch": 0.22,
234
+ "learning_rate": 3.589916169607209e-05,
235
+ "loss": 0.1019,
236
  "step": 76
237
  },
238
  {
239
  "epoch": 0.22,
240
+ "learning_rate": 3.567686794932943e-05,
241
+ "loss": 0.1016,
242
  "step": 78
243
  },
244
  {
245
  "epoch": 0.23,
246
+ "learning_rate": 3.544943437890238e-05,
247
+ "loss": 0.1057,
248
  "step": 80
249
  },
250
  {
251
  "epoch": 0.23,
252
+ "learning_rate": 3.5216935551244896e-05,
253
+ "loss": 0.104,
254
  "step": 82
255
  },
256
  {
257
  "epoch": 0.24,
258
+ "learning_rate": 3.4979447693508e-05,
259
+ "loss": 0.103,
260
  "step": 84
261
  },
262
  {
263
  "epoch": 0.25,
264
+ "learning_rate": 3.4737048668547995e-05,
265
+ "loss": 0.1039,
266
  "step": 86
267
  },
268
  {
269
  "epoch": 0.25,
270
+ "learning_rate": 3.4489817949398224e-05,
271
+ "loss": 0.0955,
272
  "step": 88
273
  },
274
  {
275
  "epoch": 0.26,
276
+ "learning_rate": 3.423783659321307e-05,
277
+ "loss": 0.1059,
278
  "step": 90
279
  },
280
  {
281
  "epoch": 0.26,
282
+ "learning_rate": 3.398118721469255e-05,
283
+ "loss": 0.1028,
284
  "step": 92
285
  },
286
  {
287
  "epoch": 0.27,
288
+ "learning_rate": 3.371995395899618e-05,
289
+ "loss": 0.1056,
290
  "step": 94
291
  },
292
  {
293
  "epoch": 0.27,
294
+ "learning_rate": 3.345422247415512e-05,
295
+ "loss": 0.0985,
296
  "step": 96
297
  },
298
  {
299
  "epoch": 0.28,
300
+ "learning_rate": 3.3184079882991606e-05,
301
+ "loss": 0.106,
302
  "step": 98
303
  },
304
  {
305
  "epoch": 0.28,
306
+ "learning_rate": 3.29096147545548e-05,
307
+ "loss": 0.0983,
308
  "step": 100
309
  },
310
  {
311
  "epoch": 0.29,
312
+ "learning_rate": 3.2630917075082545e-05,
313
+ "loss": 0.0979,
314
  "step": 102
315
  },
316
  {
317
  "epoch": 0.3,
318
+ "learning_rate": 3.234807821849838e-05,
319
+ "loss": 0.0987,
320
  "step": 104
321
  },
322
  {
323
  "epoch": 0.3,
324
+ "learning_rate": 3.2061190916453745e-05,
325
+ "loss": 0.1096,
326
  "step": 106
327
  },
328
  {
329
  "epoch": 0.31,
330
+ "learning_rate": 3.1770349227924854e-05,
331
+ "loss": 0.11,
332
  "step": 108
333
  },
334
  {
335
  "epoch": 0.31,
336
+ "learning_rate": 3.147564850837455e-05,
337
+ "loss": 0.1004,
338
  "step": 110
339
  },
340
  {
341
  "epoch": 0.32,
342
+ "learning_rate": 3.1177185378488984e-05,
343
+ "loss": 0.0939,
344
  "step": 112
345
  },
346
  {
347
  "epoch": 0.32,
348
+ "learning_rate": 3.0875057692499566e-05,
349
+ "loss": 0.0944,
350
  "step": 114
351
  },
352
  {
353
  "epoch": 0.33,
354
+ "learning_rate": 3.05693645061004e-05,
355
+ "loss": 0.1117,
356
  "step": 116
357
  },
358
  {
359
  "epoch": 0.34,
360
+ "learning_rate": 3.0260206043971857e-05,
361
+ "loss": 0.0962,
362
  "step": 118
363
  },
364
  {
365
  "epoch": 0.34,
366
+ "learning_rate": 2.9947683666920913e-05,
367
+ "loss": 0.0993,
368
  "step": 120
369
  },
370
  {
371
  "epoch": 0.35,
372
+ "learning_rate": 2.9631899838648887e-05,
373
+ "loss": 0.0946,
374
  "step": 122
375
  },
376
  {
377
  "epoch": 0.35,
378
+ "learning_rate": 2.9312958092157724e-05,
379
+ "loss": 0.1003,
380
  "step": 124
381
  },
382
  {
383
  "epoch": 0.36,
384
+ "learning_rate": 2.8990962995805577e-05,
385
+ "loss": 0.1009,
386
  "step": 126
387
  },
388
  {
389
  "epoch": 0.36,
390
+ "learning_rate": 2.866602011902301e-05,
391
+ "loss": 0.0913,
392
  "step": 128
393
  },
394
  {
395
  "epoch": 0.37,
396
+ "learning_rate": 2.833823599770098e-05,
397
+ "loss": 0.0961,
398
  "step": 130
399
  },
400
  {
401
  "epoch": 0.38,
402
+ "learning_rate": 2.8007718099261886e-05,
403
+ "loss": 0.1042,
404
  "step": 132
405
  },
406
  {
407
  "epoch": 0.38,
408
+ "learning_rate": 2.767457478742533e-05,
409
+ "loss": 0.1049,
410
  "step": 134
411
  },
412
  {
413
  "epoch": 0.39,
414
+ "learning_rate": 2.733891528667991e-05,
415
+ "loss": 0.1063,
416
  "step": 136
417
  },
418
  {
419
  "epoch": 0.39,
420
+ "learning_rate": 2.7000849646472826e-05,
421
+ "loss": 0.1028,
422
  "step": 138
423
  },
424
  {
425
  "epoch": 0.4,
426
+ "learning_rate": 2.6660488705129054e-05,
427
+ "loss": 0.0973,
428
  "step": 140
429
  },
430
  {
431
  "epoch": 0.4,
432
+ "learning_rate": 2.6317944053511853e-05,
433
+ "loss": 0.1022,
434
  "step": 142
435
  },
436
  {
437
  "epoch": 0.41,
438
+ "learning_rate": 2.5973327998436527e-05,
439
+ "loss": 0.1044,
440
  "step": 144
441
  },
442
  {
443
  "epoch": 0.42,
444
+ "learning_rate": 2.562675352584947e-05,
445
+ "loss": 0.104,
446
  "step": 146
447
  },
448
  {
449
  "epoch": 0.42,
450
+ "learning_rate": 2.5278334263784587e-05,
451
+ "loss": 0.1015,
452
  "step": 148
453
  },
454
  {
455
  "epoch": 0.43,
456
+ "learning_rate": 2.4928184445109108e-05,
457
+ "loss": 0.1026,
458
  "step": 150
459
  },
460
  {
461
  "epoch": 0.43,
462
+ "learning_rate": 2.457641887007121e-05,
463
+ "loss": 0.1108,
464
  "step": 152
465
  },
466
  {
467
  "epoch": 0.44,
468
+ "learning_rate": 2.4223152868661535e-05,
469
+ "loss": 0.104,
470
  "step": 154
471
  },
472
  {
473
  "epoch": 0.44,
474
+ "learning_rate": 2.3868502262801065e-05,
475
+ "loss": 0.1013,
476
  "step": 156
477
  },
478
  {
479
  "epoch": 0.45,
480
+ "learning_rate": 2.3512583328367717e-05,
481
+ "loss": 0.1004,
482
  "step": 158
483
  },
484
  {
485
  "epoch": 0.46,
486
+ "learning_rate": 2.3155512757074065e-05,
487
+ "loss": 0.0986,
488
  "step": 160
489
  },
490
  {
491
  "epoch": 0.46,
492
+ "learning_rate": 2.2797407618208784e-05,
493
+ "loss": 0.0986,
494
  "step": 162
495
  },
496
  {
497
  "epoch": 0.47,
498
+ "learning_rate": 2.2438385320254234e-05,
499
+ "loss": 0.1106,
500
  "step": 164
501
  },
502
  {
503
  "epoch": 0.47,
504
+ "learning_rate": 2.2078563572392907e-05,
505
+ "loss": 0.097,
506
  "step": 166
507
  },
508
  {
509
  "epoch": 0.48,
510
+ "learning_rate": 2.171806034591522e-05,
511
+ "loss": 0.0981,
512
  "step": 168
513
  },
514
  {
515
  "epoch": 0.48,
516
+ "learning_rate": 2.135699383554144e-05,
517
+ "loss": 0.1088,
518
  "step": 170
519
  },
520
  {
521
  "epoch": 0.49,
522
+ "learning_rate": 2.099548242067028e-05,
523
+ "loss": 0.0911,
524
  "step": 172
525
  },
526
  {
527
  "epoch": 0.5,
528
+ "learning_rate": 2.0633644626567007e-05,
529
+ "loss": 0.0978,
530
  "step": 174
531
  },
532
  {
533
  "epoch": 0.5,
534
+ "learning_rate": 2.0271599085503722e-05,
535
+ "loss": 0.0912,
536
  "step": 176
537
  },
538
  {
539
  "epoch": 0.51,
540
+ "learning_rate": 1.9909464497864487e-05,
541
+ "loss": 0.107,
542
  "step": 178
543
  },
544
  {
545
  "epoch": 0.51,
546
+ "learning_rate": 1.954735959322825e-05,
547
+ "loss": 0.1027,
548
  "step": 180
549
  },
550
  {
551
  "epoch": 0.52,
552
+ "learning_rate": 1.9185403091442044e-05,
553
+ "loss": 0.1048,
554
  "step": 182
555
  },
556
  {
557
  "epoch": 0.52,
558
+ "learning_rate": 1.882371366369749e-05,
559
+ "loss": 0.0915,
560
  "step": 184
561
  },
562
  {
563
  "epoch": 0.53,
564
+ "learning_rate": 1.846240989362325e-05,
565
+ "loss": 0.1041,
566
  "step": 186
567
  },
568
  {
569
  "epoch": 0.54,
570
+ "learning_rate": 1.810161023840607e-05,
571
+ "loss": 0.1001,
572
  "step": 188
573
  },
574
  {
575
  "epoch": 0.54,
576
+ "learning_rate": 1.774143298995346e-05,
577
+ "loss": 0.1043,
578
  "step": 190
579
  },
580
  {
581
  "epoch": 0.55,
582
+ "learning_rate": 1.7381996236110386e-05,
583
+ "loss": 0.1067,
584
  "step": 192
585
  },
586
  {
587
  "epoch": 0.55,
588
+ "learning_rate": 1.702341782194301e-05,
589
+ "loss": 0.1095,
590
  "step": 194
591
  },
592
  {
593
  "epoch": 0.56,
594
+ "learning_rate": 1.6665815311101896e-05,
595
+ "loss": 0.1016,
596
  "step": 196
597
  },
598
  {
599
  "epoch": 0.56,
600
+ "learning_rate": 1.630930594727762e-05,
601
+ "loss": 0.0963,
602
  "step": 198
603
  },
604
  {
605
  "epoch": 0.57,
606
+ "learning_rate": 1.5954006615761158e-05,
607
+ "loss": 0.1036,
608
  "step": 200
609
  },
610
  {
611
  "epoch": 0.58,
612
+ "learning_rate": 1.560003380512185e-05,
613
+ "loss": 0.1136,
614
  "step": 202
615
  },
616
  {
617
  "epoch": 0.58,
618
+ "learning_rate": 1.5247503569015413e-05,
619
+ "loss": 0.0947,
620
  "step": 204
621
  },
622
  {
623
  "epoch": 0.59,
624
+ "learning_rate": 1.489653148813455e-05,
625
+ "loss": 0.1105,
626
  "step": 206
627
  },
628
  {
629
  "epoch": 0.59,
630
+ "learning_rate": 1.4547232632314624e-05,
631
+ "loss": 0.1033,
632
  "step": 208
633
  },
634
  {
635
  "epoch": 0.6,
636
+ "learning_rate": 1.4199721522806807e-05,
637
+ "loss": 0.102,
638
  "step": 210
639
  },
640
  {
641
  "epoch": 0.6,
642
+ "learning_rate": 1.3854112094731116e-05,
643
+ "loss": 0.1037,
644
  "step": 212
645
  },
646
  {
647
  "epoch": 0.61,
648
+ "learning_rate": 1.3510517659721583e-05,
649
+ "loss": 0.1005,
650
  "step": 214
651
  },
652
  {
653
  "epoch": 0.62,
654
+ "learning_rate": 1.316905086877589e-05,
655
+ "loss": 0.0979,
656
  "step": 216
657
  },
658
  {
659
  "epoch": 0.62,
660
+ "learning_rate": 1.2829823675321535e-05,
661
+ "loss": 0.1007,
662
  "step": 218
663
  },
664
  {
665
  "epoch": 0.63,
666
+ "learning_rate": 1.2492947298510783e-05,
667
+ "loss": 0.1002,
668
  "step": 220
669
  },
670
  {
671
  "epoch": 0.63,
672
+ "learning_rate": 1.2158532186756275e-05,
673
+ "loss": 0.1037,
674
  "step": 222
675
  },
676
  {
677
  "epoch": 0.64,
678
+ "learning_rate": 1.182668798151939e-05,
679
+ "loss": 0.1018,
680
  "step": 224
681
  },
682
  {
683
  "epoch": 0.64,
684
+ "learning_rate": 1.1497523481363146e-05,
685
+ "loss": 0.1002,
686
  "step": 226
687
  },
688
  {
689
  "epoch": 0.65,
690
+ "learning_rate": 1.1171146606281482e-05,
691
+ "loss": 0.0982,
692
  "step": 228
693
  },
694
  {
695
  "epoch": 0.66,
696
+ "learning_rate": 1.0847664362316549e-05,
697
+ "loss": 0.102,
698
  "step": 230
699
  },
700
  {
701
  "epoch": 0.66,
702
+ "learning_rate": 1.0527182806475662e-05,
703
+ "loss": 0.0928,
704
  "step": 232
705
  },
706
  {
707
  "epoch": 0.67,
708
+ "learning_rate": 1.020980701195946e-05,
709
+ "loss": 0.0996,
710
  "step": 234
711
  },
712
  {
713
  "epoch": 0.67,
714
+ "learning_rate": 9.895641033712507e-06,
715
+ "loss": 0.1014,
716
  "step": 236
717
  },
718
  {
719
  "epoch": 0.68,
720
+ "learning_rate": 9.584787874307828e-06,
721
+ "loss": 0.0994,
722
  "step": 238
723
  },
724
  {
725
  "epoch": 0.68,
726
+ "learning_rate": 9.277349450176445e-06,
727
+ "loss": 0.1092,
728
  "step": 240
729
  },
730
  {
731
  "epoch": 0.69,
732
+ "learning_rate": 8.97342655819303e-06,
733
+ "loss": 0.0992,
734
  "step": 242
735
  },
736
  {
737
  "epoch": 0.7,
738
+ "learning_rate": 8.673118842628595e-06,
739
+ "loss": 0.0892,
740
  "step": 244
741
  },
742
  {
743
  "epoch": 0.7,
744
+ "learning_rate": 8.376524762481069e-06,
745
+ "loss": 0.0975,
746
  "step": 246
747
  },
748
  {
749
  "epoch": 0.71,
750
+ "learning_rate": 8.083741559194515e-06,
751
+ "loss": 0.0982,
752
  "step": 248
753
  },
754
  {
755
  "epoch": 0.71,
756
+ "learning_rate": 7.794865224777504e-06,
757
+ "loss": 0.1026,
758
  "step": 250
759
  },
760
  {
761
  "epoch": 0.72,
762
+ "learning_rate": 7.509990470331159e-06,
763
+ "loss": 0.0973,
764
  "step": 252
765
  },
766
  {
767
  "epoch": 0.72,
768
+ "learning_rate": 7.229210694997113e-06,
769
+ "loss": 0.0985,
770
  "step": 254
771
  },
772
  {
773
  "epoch": 0.73,
774
+ "learning_rate": 6.952617955335641e-06,
775
+ "loss": 0.1005,
776
  "step": 256
777
  },
778
  {
779
  "epoch": 0.74,
780
+ "learning_rate": 6.680302935143963e-06,
781
+ "loss": 0.0968,
782
  "step": 258
783
  },
784
  {
785
  "epoch": 0.74,
786
+ "learning_rate": 6.412354915724642e-06,
787
+ "loss": 0.1079,
788
  "step": 260
789
  },
790
  {
791
  "epoch": 0.75,
792
+ "learning_rate": 6.14886174661373e-06,
793
+ "loss": 0.0994,
794
  "step": 262
795
  },
796
  {
797
  "epoch": 0.75,
798
+ "learning_rate": 5.889909816778458e-06,
799
+ "loss": 0.0991,
800
  "step": 264
801
  },
802
  {
803
  "epoch": 0.76,
804
+ "learning_rate": 5.635584026293655e-06,
805
+ "loss": 0.098,
806
  "step": 266
807
  },
808
  {
809
  "epoch": 0.76,
810
+ "learning_rate": 5.385967758506407e-06,
811
+ "loss": 0.1035,
812
  "step": 268
813
  },
814
  {
815
  "epoch": 0.77,
816
+ "learning_rate": 5.141142852697956e-06,
817
+ "loss": 0.1022,
818
  "step": 270
819
  },
820
  {
821
  "epoch": 0.77,
822
+ "learning_rate": 4.901189577251864e-06,
823
+ "loss": 0.0938,
824
  "step": 272
825
  },
826
  {
827
  "epoch": 0.78,
828
+ "learning_rate": 4.6661866033371506e-06,
829
+ "loss": 0.1026,
830
  "step": 274
831
  },
832
  {
833
  "epoch": 0.79,
834
+ "learning_rate": 4.4362109791151695e-06,
835
+ "loss": 0.0983,
836
  "step": 276
837
  },
838
  {
839
  "epoch": 0.79,
840
+ "learning_rate": 4.211338104478548e-06,
841
+ "loss": 0.1036,
842
  "step": 278
843
  },
844
  {
845
  "epoch": 0.8,
846
+ "learning_rate": 3.991641706330575e-06,
847
+ "loss": 0.092,
848
  "step": 280
849
  },
850
  {
851
  "epoch": 0.8,
852
+ "learning_rate": 3.777193814413045e-06,
853
+ "loss": 0.1038,
854
  "step": 282
855
  },
856
  {
857
  "epoch": 0.81,
858
+ "learning_rate": 3.5680647376905666e-06,
859
+ "loss": 0.1,
860
  "step": 284
861
  },
862
  {
863
  "epoch": 0.81,
864
+ "learning_rate": 3.3643230412990625e-06,
865
+ "loss": 0.1053,
866
  "step": 286
867
  },
868
  {
869
  "epoch": 0.82,
870
+ "learning_rate": 3.1660355240659423e-06,
871
+ "loss": 0.1029,
872
  "step": 288
873
  },
874
  {
875
  "epoch": 0.83,
876
+ "learning_rate": 2.973267196609453e-06,
877
+ "loss": 0.0974,
878
  "step": 290
879
  },
880
  {
881
  "epoch": 0.83,
882
+ "learning_rate": 2.786081260024236e-06,
883
+ "loss": 0.0966,
884
  "step": 292
885
  },
886
  {
887
  "epoch": 0.84,
888
+ "learning_rate": 2.604539085160218e-06,
889
+ "loss": 0.1029,
890
  "step": 294
891
  },
892
  {
893
  "epoch": 0.84,
894
+ "learning_rate": 2.428700192501534e-06,
895
+ "loss": 0.1031,
896
  "step": 296
897
  },
898
  {
899
  "epoch": 0.85,
900
+ "learning_rate": 2.2586222326521277e-06,
901
+ "loss": 0.1017,
902
  "step": 298
903
  },
904
  {
905
  "epoch": 0.85,
906
+ "learning_rate": 2.0943609674343833e-06,
907
+ "loss": 0.1021,
908
  "step": 300
909
  },
910
  {
911
  "epoch": 0.86,
912
+ "learning_rate": 1.9359702516070553e-06,
913
+ "loss": 0.103,
914
  "step": 302
915
  },
916
  {
917
  "epoch": 0.87,
918
+ "learning_rate": 1.7835020152084116e-06,
919
+ "loss": 0.1023,
920
  "step": 304
921
  },
922
  {
923
  "epoch": 0.87,
924
+ "learning_rate": 1.6370062465304503e-06,
925
+ "loss": 0.1034,
926
  "step": 306
927
  },
928
  {
929
  "epoch": 0.88,
930
+ "learning_rate": 1.496530975729693e-06,
931
+ "loss": 0.0926,
932
  "step": 308
933
  },
934
  {
935
  "epoch": 0.88,
936
+ "learning_rate": 1.3621222590800342e-06,
937
+ "loss": 0.099,
938
  "step": 310
939
  },
940
  {
941
  "epoch": 0.89,
942
+ "learning_rate": 1.2338241638726811e-06,
943
+ "loss": 0.0985,
944
  "step": 312
945
  },
946
  {
947
  "epoch": 0.89,
948
+ "learning_rate": 1.1116787539682571e-06,
949
+ "loss": 0.1044,
950
  "step": 314
951
  },
952
  {
953
  "epoch": 0.9,
954
+ "learning_rate": 9.957260760057164e-07,
955
+ "loss": 0.0929,
956
  "step": 316
957
  },
958
  {
959
  "epoch": 0.91,
960
+ "learning_rate": 8.860041462726543e-07,
961
+ "loss": 0.0988,
962
  "step": 318
963
  },
964
  {
965
  "epoch": 0.91,
966
+ "learning_rate": 7.825489382412521e-07,
967
+ "loss": 0.1043,
968
  "step": 320
969
  },
970
  {
971
  "epoch": 0.92,
972
+ "learning_rate": 6.853943707740218e-07,
973
+ "loss": 0.0975,
974
  "step": 322
975
  },
976
  {
977
  "epoch": 0.92,
978
+ "learning_rate": 5.945722970031332e-07,
979
+ "loss": 0.1065,
980
  "step": 324
981
  },
982
  {
983
  "epoch": 0.93,
984
+ "learning_rate": 5.101124938870605e-07,
985
+ "loss": 0.0978,
986
  "step": 326
987
  },
988
  {
989
  "epoch": 0.93,
990
+ "learning_rate": 4.320426524478749e-07,
991
+ "loss": 0.1,
992
  "step": 328
993
  },
994
  {
995
  "epoch": 0.94,
996
+ "learning_rate": 3.603883686924681e-07,
997
+ "loss": 0.0967,
998
  "step": 330
999
  },
1000
  {
1001
  "epoch": 0.95,
1002
+ "learning_rate": 2.951731352206322e-07,
1003
+ "loss": 0.0954,
1004
  "step": 332
1005
  },
1006
  {
1007
  "epoch": 0.95,
1008
+ "learning_rate": 2.3641833352276768e-07,
1009
+ "loss": 0.1016,
1010
  "step": 334
1011
  },
1012
  {
1013
  "epoch": 0.96,
1014
+ "learning_rate": 1.841432269697463e-07,
1015
+ "loss": 0.0993,
1016
  "step": 336
1017
  },
1018
  {
1019
  "epoch": 0.96,
1020
+ "learning_rate": 1.3836495449719878e-07,
1021
+ "loss": 0.1002,
1022
  "step": 338
1023
  },
1024
  {
1025
  "epoch": 0.97,
1026
+ "learning_rate": 9.90985249863563e-08,
1027
+ "loss": 0.1043,
1028
  "step": 340
1029
  },
1030
  {
1031
  "epoch": 0.97,
1032
+ "learning_rate": 6.635681234321789e-08,
1033
+ "loss": 0.0984,
1034
  "step": 342
1035
  },
1036
  {
1037
  "epoch": 0.98,
1038
+ "learning_rate": 4.0150551277724494e-08,
1039
+ "loss": 0.1046,
1040
  "step": 344
1041
  },
1042
  {
1043
  "epoch": 0.99,
1044
+ "learning_rate": 2.0488333784249858e-08,
1045
+ "loss": 0.0948,
1046
  "step": 346
1047
  },
1048
  {
1049
  "epoch": 0.99,
1050
+ "learning_rate": 7.376606324644986e-09,
1051
+ "loss": 0.0897,
1052
  "step": 348
1053
  },
1054
  {
1055
  "epoch": 1.0,
1056
+ "learning_rate": 8.196677146932175e-10,
1057
+ "loss": 0.0971,
1058
  "step": 350
 
 
 
 
 
 
 
 
 
1059
  }
1060
  ],
1061
  "max_steps": 351,
1062
  "num_train_epochs": 1,
1063
+ "total_flos": 2.454170351173632e+17,
1064
  "trial_name": null,
1065
  "trial_params": null
1066
  }
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:64b27d88bdfcbf9a4147d42a9156ccaa0672586f3f9f8e75dd07035a21609949
3
- size 4463
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:773009bb9ae9edbbd62a95f7459bff6131e643ef55e6dc91013d04842d7343bb
3
+ size 4847
zero_to_fp32.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # This script extracts fp32 consolidated weights from ZeRO stage 2 and 3 DeepSpeed checkpoints. It gets
4
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
5
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
6
+ # application.
7
+ #
8
+ # example: python zero_to_fp32.py . pytorch_model.bin
9
+
10
+ import argparse
11
+ import torch
12
+ import glob
13
+ import math
14
+ import os
15
+ import re
16
+ from collections import OrderedDict
17
+
18
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
19
+ # DeepSpeed data structures it has to be available in the current python environment.
20
+ from deepspeed.utils import logger
21
+ from deepspeed.checkpoint.constants import (DS_VERSION,
22
+ OPTIMIZER_STATE_DICT,
23
+ SINGLE_PARTITION_OF_FP32_GROUPS,
24
+ FP32_FLAT_GROUPS,
25
+ ZERO_STAGE,
26
+ PARTITION_COUNT,
27
+ PARAM_SHAPES,
28
+ BUFFER_NAMES)
29
+
30
# Module-wide switch: any truthy value enables the extra diagnostic ``print`` calls that the
# functions in this file guard with ``if debug:``.
debug = 0

# load to cpu
# (passed as ``map_location`` to every ``torch.load`` call in this file)
device = torch.device('cpu')
34
+
35
+
36
def atoi(text):
    """
    Convert a text fragment to ``int`` when it consists solely of decimal digits.

    Args:
        - ``text``: a string fragment (typically one piece of a ``re.split`` on digit runs)

    Returns:
        - ``int(text)`` if ``text`` is non-empty and all-decimal, otherwise ``text`` unchanged
    """
    # ``str.isdecimal`` rather than ``str.isdigit``: the latter is also true for characters such
    # as superscript digits ("²") which ``int()`` rejects with ValueError.
    return int(text) if text.isdecimal() else text
38
+
39
+
40
def natural_keys(text):
    """
    Build a sort key that orders strings the way a human would ("file2" before "file10").

    Use as ``alist.sort(key=natural_keys)``.
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    Args:
        - ``text``: the string to derive a key from

    Returns:
        - list of the string's pieces, split around each run of digits, each piece passed
          through ``atoi``
    """
    # the capturing group keeps the digit runs in the split result
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))
47
+
48
+
49
def get_model_state_file(checkpoint_dir, zero_stage):
    """
    Locate the model-states file inside a DeepSpeed checkpoint folder.

    Args:
        - ``checkpoint_dir``: path to the folder that holds the ``*_model_states.pt`` file
        - ``zero_stage``: ZeRO stage of the checkpoint, ``2`` or ``3``

    Returns:
        - path to the model-states file

    Raises:
        - ``FileNotFoundError``: ``checkpoint_dir`` is not a directory, or the expected file is missing
        - ``ValueError``: ``zero_stage`` is neither 2 nor 3
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage == 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
    else:
        # without this branch ``file`` would be unbound below and surface as an
        # UnboundLocalError instead of a meaningful message
        raise ValueError(f"unknown zero stage {zero_stage}")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
63
+
64
+
65
+ def get_optim_files(checkpoint_dir):
66
+ # XXX: need to test that this simple glob rule works for multi-node setup too
67
+ optim_files = sorted(glob.glob(os.path.join(checkpoint_dir,
68
+ "*_optim_states.pt")),
69
+ key=natural_keys)
70
+
71
+ if len(optim_files) == 0:
72
+ raise FileNotFoundError(
73
+ f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
74
+
75
+ return optim_files
76
+
77
+
78
def parse_model_state(file):
    """
    Read a model-states checkpoint file and pull out what the reconstruction needs.

    Args:
        - ``file``: path to the ``*_model_states.pt`` file

    Returns:
        - tuple ``(buffers, param_shapes, ds_version)``: the module's buffers converted with
          ``.float()``, the value stored under ``PARAM_SHAPES``, and the value stored under
          ``DS_VERSION`` (``None`` when absent)

    Raises:
        - ``ValueError``: the loaded object has no ``BUFFER_NAMES`` entry
    """
    checkpoint = torch.load(file, map_location=device)

    if BUFFER_NAMES not in checkpoint:
        raise ValueError(f"{file} is not a model state checkpoint")

    buffer_names = checkpoint[BUFFER_NAMES]
    if debug:
        print("Found buffers:", buffer_names)

    # recover just the buffers while restoring them to fp32 if they were saved in fp16
    buffers = {}
    for key, tensor in checkpoint["module"].items():
        if key in buffer_names:
            buffers[key] = tensor.float()

    param_shapes = checkpoint[PARAM_SHAPES]
    ds_version = checkpoint.get(DS_VERSION, None)

    return buffers, param_shapes, ds_version
98
+
99
+
100
def parse_optim_states(files, ds_checkpoint_dir):
    """
    Load every optimizer-state file and extract the flat fp32 weight partitions.

    Args:
        - ``files``: paths of the ``*_optim_states.pt`` files, one per saved partition
        - ``ds_checkpoint_dir``: folder the files came from (only used in the error message)

    Returns:
        - tuple ``(zero_stage, world_size, fp32_flat_groups)`` where ``fp32_flat_groups`` has
          one entry per file: for stage 2 the stored list of per-group tensors, for stage 3 a
          single tensor made by concatenating that file's per-group tensors

    Raises:
        - ``ValueError``: the first file has no ``ZERO_STAGE`` entry, the number of files does
          not match the recorded partition count, or the stage is neither 2 nor 3
    """
    total_files = len(files)
    state_dicts = [torch.load(f, map_location=device) for f in files]

    first_optim_state = state_dicts[0][OPTIMIZER_STATE_DICT]
    if ZERO_STAGE not in first_optim_state:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = first_optim_state[ZERO_STAGE]
    world_size = first_optim_state[PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.
    if isinstance(world_size, list):
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage == 2:
        fp32_flat_groups = [
            sd[OPTIMIZER_STATE_DICT][SINGLE_PARTITION_OF_FP32_GROUPS] for sd in state_dicts
        ]
    elif zero_stage == 3:
        # if there is more than one param group, there will be multiple flattened tensors - one
        # flattened tensor per group - for simplicity merge them into a single tensor
        #
        # XXX: could make the script more memory efficient for when there are multiple groups - it
        # will require matching the sub-lists of param_shapes for each param group flattened tensor
        fp32_flat_groups = [
            torch.cat(sd[OPTIMIZER_STATE_DICT][FP32_FLAT_GROUPS], 0) for sd in state_dicts
        ]
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    return zero_stage, world_size, fp32_flat_groups
151
+
152
+
153
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
    """
    Rebuild an fp32 state_dict from the files of one DeepSpeed checkpoint folder.

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)

    Returns:
        - the state_dict produced by the stage-2 or stage-3 reconstruction routine
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    zero_stage, world_size, fp32_flat_groups = parse_optim_states(
        get_optim_files(ds_checkpoint_dir),
        ds_checkpoint_dir)
    print(
        f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    buffers, param_shapes, ds_version = parse_model_state(
        get_model_state_file(ds_checkpoint_dir, zero_stage))
    print(f'Parsing checkpoint created by deepspeed=={ds_version}')

    reconstruct_args = (world_size, param_shapes, fp32_flat_groups, buffers)
    if zero_stage == 2:
        return _get_fp32_state_dict_from_zero2_checkpoint(*reconstruct_args)
    if zero_stage == 3:
        return _get_fp32_state_dict_from_zero3_checkpoint(*reconstruct_args)
    # any other stage falls through without a result
    return None
182
+
183
+
184
def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
                                               param_shapes,
                                               fp32_flat_groups,
                                               buffers):
    """
    Reassemble the full fp32 parameters of a ZeRO stage-2 checkpoint.

    For each param group the per-file partitions are concatenated into one flat vector, and
    every parameter is then carved out of that vector, in ``param_shapes`` order, as a view
    with the recorded shape.

    Args:
        - ``world_size``: number of partitions the optimizer state was saved in
        - ``param_shapes``: one name-to-shape mapping per param group
        - ``fp32_flat_groups``: per file, the list of flat fp32 tensors (one per param group)
        - ``buffers``: name-to-tensor mapping that is copied into the result first

    Returns:
        - ``OrderedDict`` holding the buffers followed by the reconstructed parameters

    Raises:
        - ``ValueError``: after aligning both to ``2 * world_size``, the consumed element count
          of a group differs from that group's available element count
    """

    # Reconstruction protocol:
    #
    # XXX: document this

    num_param_groups = len(fp32_flat_groups[0])

    if debug:
        for rank in range(world_size):
            for group_idx in range(num_param_groups):
                print(
                    f"{FP32_FLAT_GROUPS}[{rank}][{group_idx}].shape={fp32_flat_groups[rank][group_idx].shape}")

    # XXX: memory usage doubles here (zero2)
    merged_groups = [
        torch.cat([rank_groups[group_idx] for rank_groups in fp32_flat_groups], 0)
        for group_idx in range(num_param_groups)
    ]
    avail_numel = sum(flat_vector.numel() for flat_vector in merged_groups)

    if debug:
        wanted_params = sum(len(shapes) for shapes in param_shapes)
        wanted_numel = sum(shape.numel() for shapes in param_shapes for shape in shapes.values())
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # buffers
    state_dict = OrderedDict(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
    # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
    # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
    # live optimizer object, so we are checking that the numbers are within the right range
    align_to = 2 * world_size

    def zero2_align(x):
        return align_to * math.ceil(x / align_to)

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, flat_vector in zip(param_shapes, merged_groups):
        offset = 0
        avail_numel = flat_vector.numel()
        for name, shape in shapes.items():
            numel = shape.numel()
            total_numel += numel
            total_params += 1

            if debug:
                print(
                    f"{name} full shape: {shape} unpartitioned numel {numel} "
                )
            state_dict[name] = flat_vector.narrow(0, offset, numel).view(shape)
            offset += numel

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(
                f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(
        f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
    )

    return state_dict
278
+
279
+
280
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """
    Work out how a parameter of ``unpartitioned_numel`` elements is split over ``world_size`` ranks.

    Args:
        - ``unpartitioned_numel``: element count of the full parameter
        - ``world_size``: number of partitions

    Returns:
        - tuple ``(partitioned_numel, padding_numel)``: elements per partition (rounded up) and
          the number of padding elements needed to make the total divisible by ``world_size``
    """
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    # exact integer ceil-division; ``math.ceil(a / b)`` goes through a float and loses precision
    # once the element count exceeds 2**53
    partitioned_numel = (unpartitioned_numel + padding_numel) // world_size
    return partitioned_numel, padding_numel
285
+
286
+
287
def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
                                               param_shapes,
                                               fp32_flat_groups,
                                               buffers):
    """
    Reassemble the full fp32 parameters of a ZeRO stage-3 checkpoint.

    Args:
        - ``world_size``: number of partitions the optimizer state was saved in
        - ``param_shapes``: list of name-to-shape mappings (merged here into one, order kept)
        - ``fp32_flat_groups``: one flat fp32 tensor per file
        - ``buffers``: name-to-tensor mapping that is copied into the result first

    Returns:
        - ``OrderedDict`` holding the buffers followed by the reconstructed parameters

    Raises:
        - ``ValueError``: the number of consumed elements does not equal
          ``fp32_flat_groups[0].numel() * world_size``
    """

    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    avail_numel = fp32_flat_groups[0].numel() * world_size
    # merge list of dicts, preserving order
    merged_shapes = {}
    for group_shapes in param_shapes:
        merged_shapes.update(group_shapes)

    # NOTE(review): nesting was reconstructed from a flattened listing - the summary prints are
    # kept under ``if debug`` to mirror the zero2 routine; confirm against the upstream file.
    if debug:
        for rank in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{rank}].shape={fp32_flat_groups[rank].shape}")

        wanted_params = len(merged_shapes)
        wanted_numel = sum(shape.numel() for shape in merged_shapes.values())
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # buffers
    state_dict = OrderedDict(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    for name, shape in merged_shapes.items():
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
            unpartitioned_numel, world_size)

        if debug:
            print(
                f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # XXX: memory usage doubles here
        # take this param's slice from every rank, glue the slices together, then drop the
        # trailing padding before reshaping
        rank_slices = [
            fp32_flat_groups[rank].narrow(0, offset, partitioned_numel)
            for rank in range(world_size)
        ]
        state_dict[name] = torch.cat(rank_slices, 0).narrow(0, 0, unpartitioned_numel).view(shape)
        offset += partitioned_numel

    # offset so far counts elements per rank; scale to the total across ranks
    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(
            f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(
        f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
    )

    return state_dict
358
+
359
+
360
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
    """
    Turn a ZeRO 2 or 3 checkpoint into one consolidated fp32 state_dict, ready for
    ``load_state_dict()`` - e.g. to keep training without DeepSpeed or to publish the weights
    on a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint, e.g. ``global_step14``.
          When omitted, the tag is read from the file named ``latest`` inside ``checkpoint_dir``

    Returns:
        - pytorch ``state_dict``

    Raises:
        - ``ValueError``: ``tag`` was not given and there is no ``latest`` file
        - ``FileNotFoundError``: the ``checkpoint_dir/tag`` folder doesn't exist

    Note: the whole state_dict is built in CPU memory. If the application doesn't have enough
    free CPU memory, use the offline approach via the ``zero_to_fp32.py`` script that is saved
    with the checkpoint.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    After this the ``model`` is no longer usable in the deepspeed context of the same
    application, i.e. the deepspeed engine has to be re-initialized, because
    ``model.load_state_dict(state_dict)`` removes all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    """
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
407
+
408
+
409
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
    """
    Consolidate a ZeRO 2 or 3 checkpoint into one fp32 ``state_dict`` and write it to ``output_file``.
    The written file can be read back with ``torch.load(file)`` + ``load_state_dict()`` and used for
    training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
    """
    # Build the consolidated weights first, then announce and write the output file.
    fp32_weights = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    print(f"Saving fp32 state dict to {output_file}")
    torch.save(fp32_weights, output_file)
423
+
424
+
425
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    The weights are loaded with ``strict=False``, so keys that don't line up between the checkpoint
    and the model do not raise; they are reported through ``logger.warning`` instead.

    """
    logger.info("Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    model = model.cpu()
    load_result = model.load_state_dict(state_dict, strict=False)

    # strict=False tolerates key mismatches; surface them so a partial load doesn't go unnoticed.
    # getattr keeps this safe for model objects whose load_state_dict() returns nothing.
    missing_keys = getattr(load_result, "missing_keys", None)
    unexpected_keys = getattr(load_result, "unexpected_keys", None)
    if missing_keys:
        logger.warning("Model keys not found in the fp32 state dict: %s", missing_keys)
    if unexpected_keys:
        logger.warning("fp32 state dict keys not used by the model: %s", unexpected_keys)

    return model
462
+
463
+
464
if __name__ == "__main__":
    # Command-line entry point: convert a ZeRO checkpoint folder into a single fp32 state_dict file.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument(
        "output_file",
        type=str,
        help=
        "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
    )
    # Optional: without it the tag is read from the 'latest' file, exactly as before.
    parser.add_argument(
        "-t",
        "--tag",
        type=str,
        default=None,
        help="checkpoint tag used as a unique identifier for checkpoint, e.g., global_step14")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # Module-level flag; presumably read by helpers elsewhere in this file - confirm before renaming.
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)