pszemraj committed
Commit 0afa3fd
1 Parent(s): 467a60b

End of training
README.md ADDED
@@ -0,0 +1,88 @@
+ ---
+ license: apache-2.0
+ tags:
+ - generated_from_trainer
+ model-index:
+ - name: checkpoints
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # checkpoints
+
+ This model is a fine-tuned version of [distilgpt2](https://huggingface.co/distilgpt2) on an unknown dataset.
+ It achieves the following results on the evaluation set:
+ - Loss: 2.2461
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 2e-05
+ - train_batch_size: 32
+ - eval_batch_size: 32
+ - seed: 42
+ - distributed_type: multi-GPU
+ - gradient_accumulation_steps: 4
+ - total_train_batch_size: 128
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+ - lr_scheduler_type: cosine
+ - lr_scheduler_warmup_ratio: 0.05
+ - num_epochs: 30
+
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss |
+ |:-------------:|:-----:|:-----:|:---------------:|
+ | No log | 1.0 | 418 | 2.7793 |
+ | 2.9952 | 2.0 | 836 | 2.6914 |
+ | 2.7684 | 3.0 | 1254 | 2.6348 |
+ | 2.685 | 4.0 | 1672 | 2.5938 |
+ | 2.6243 | 5.0 | 2090 | 2.5625 |
+ | 2.5816 | 6.0 | 2508 | 2.5332 |
+ | 2.5816 | 7.0 | 2926 | 2.5098 |
+ | 2.545 | 8.0 | 3344 | 2.4902 |
+ | 2.5083 | 9.0 | 3762 | 2.4707 |
+ | 2.4793 | 10.0 | 4180 | 2.4551 |
+ | 2.4531 | 11.0 | 4598 | 2.4395 |
+ | 2.4269 | 12.0 | 5016 | 2.4238 |
+ | 2.4269 | 13.0 | 5434 | 2.4102 |
+ | 2.4051 | 14.0 | 5852 | 2.3945 |
+ | 2.3777 | 15.0 | 6270 | 2.3848 |
+ | 2.3603 | 16.0 | 6688 | 2.3711 |
+ | 2.3394 | 17.0 | 7106 | 2.3613 |
+ | 2.3206 | 18.0 | 7524 | 2.3516 |
+ | 2.3206 | 19.0 | 7942 | 2.3398 |
+ | 2.3026 | 20.0 | 8360 | 2.3301 |
+ | 2.2823 | 21.0 | 8778 | 2.3203 |
+ | 2.2669 | 22.0 | 9196 | 2.3105 |
+ | 2.2493 | 23.0 | 9614 | 2.3027 |
+ | 2.2334 | 24.0 | 10032 | 2.2930 |
+ | 2.2334 | 25.0 | 10450 | 2.2852 |
+ | 2.2194 | 26.0 | 10868 | 2.2754 |
+ | 2.2014 | 27.0 | 11286 | 2.2695 |
+ | 2.1868 | 28.0 | 11704 | 2.2598 |
+ | 2.171 | 29.0 | 12122 | 2.2539 |
+ | 2.1597 | 30.0 | 12540 | 2.2461 |
+
+
+ ### Framework versions
+
+ - Transformers 4.16.1
+ - Pytorch 1.10.0+cu111
+ - Tokenizers 0.11.0
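The card's intended-uses section is still a stub, so the following is a minimal inference sketch for a fine-tuned distilgpt2 checkpoint like this one. The repo id is a placeholder (the commit does not record where the weights are published), and the sampling settings mirror the `text-generation` defaults in the checkpoint's deleted config.json (do_sample, max_length 50).

```python
# Minimal sketch: loading a fine-tuned distilgpt2 checkpoint such as this one for text
# generation. "your-username/checkpoints" is a placeholder repo id, not the actual
# location of these weights.
from transformers import pipeline

generator = pipeline("text-generation", model="your-username/checkpoints")
out = generator("The meaning of life is", do_sample=True, max_length=50)
print(out[0]["generated_text"])
```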
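The hyperparameters listed in the card map onto a `TrainingArguments` object roughly as below. Only the values reported above are grounded; `output_dir` and the distributed/DeepSpeed launch details are assumptions not recorded in the card.

```python
# Rough reconstruction of the reported hyperparameters as transformers TrainingArguments.
# Values come from the model card above; output_dir is illustrative. 32 samples per device
# with gradient_accumulation_steps=4 reproduces the reported total_train_batch_size of 128.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="checkpoints",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=4,
    seed=42,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    num_train_epochs=30,
)
```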
last-checkpoint/config.json DELETED
@@ -1,46 +0,0 @@
- {
-   "_name_or_path": "distilgpt2",
-   "_num_labels": 1,
-   "activation_function": "gelu_new",
-   "architectures": [
-     "GPT2LMHeadModel"
-   ],
-   "attn_pdrop": 0.1,
-   "bos_token_id": 50256,
-   "embd_pdrop": 0.1,
-   "eos_token_id": 50256,
-   "id2label": {
-     "0": "LABEL_0"
-   },
-   "initializer_range": 0.02,
-   "label2id": {
-     "LABEL_0": 0
-   },
-   "layer_norm_epsilon": 1e-05,
-   "model_type": "gpt2",
-   "n_ctx": 1024,
-   "n_embd": 768,
-   "n_head": 12,
-   "n_inner": null,
-   "n_layer": 6,
-   "n_positions": 1024,
-   "reorder_and_upcast_attn": false,
-   "resid_pdrop": 0.1,
-   "scale_attn_by_inverse_layer_idx": false,
-   "scale_attn_weights": true,
-   "summary_activation": null,
-   "summary_first_dropout": 0.1,
-   "summary_proj_to_labels": true,
-   "summary_type": "cls_index",
-   "summary_use_proj": true,
-   "task_specific_params": {
-     "text-generation": {
-       "do_sample": true,
-       "max_length": 50
-     }
-   },
-   "torch_dtype": "float16",
-   "transformers_version": "4.16.1",
-   "use_cache": false,
-   "vocab_size": 50257
- }
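The deleted configuration above is the stock distilgpt2 architecture (6 layers, 12 heads, 768-dimensional embeddings, 1024-token context, 50257-token vocabulary). As a sketch, the same architecture can be declared directly from those values; the trained weights themselves live in pytorch_model.bin, so this model is randomly initialized.

```python
# Sketch of the architecture described by the deleted config.json above (values taken
# from that file); instantiating it this way yields random, untrained weights.
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(
    vocab_size=50257,
    n_positions=1024,
    n_embd=768,
    n_layer=6,
    n_head=12,
    bos_token_id=50256,
    eos_token_id=50256,
    activation_function="gelu_new",
)
model = GPT2LMHeadModel(config)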
last-checkpoint/global_step418/mp_rank_00_model_states.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1e01de6ce43a3831ba72c1ac13b04f60cb557bdcc5129d42bcb95d5dc78edd81
- size 176426750
last-checkpoint/global_step418/zero_pp_rank_0_mp_rank_00_optim_states.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:36f6013451b3047bcc0050270d1638ebab4112c301a6abd3aaf403e82e58bbd9
- size 982958179
last-checkpoint/latest DELETED
@@ -1 +0,0 @@
- global_step418
last-checkpoint/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d4158ac6be6eb8a3c9e54b9d4aa678b213d42262a515e563e3f41306e60c4357
- size 176424894
last-checkpoint/rng_state_0.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9718ad0bd253e9900a6523cdece06f611939230f60fe782f2a9cceaccc19132b
- size 14503
last-checkpoint/trainer_state.json DELETED
@@ -1,24 +0,0 @@
- {
-   "best_metric": null,
-   "best_model_checkpoint": null,
-   "epoch": 0.998805256869773,
-   "global_step": 418,
-   "is_hyper_param_search": false,
-   "is_local_process_zero": true,
-   "is_world_process_zero": true,
-   "log_history": [
-     {
-       "epoch": 1.0,
-       "eval_loss": 2.779296875,
-       "eval_runtime": 13.8161,
-       "eval_samples_per_second": 1296.528,
-       "eval_steps_per_second": 40.532,
-       "step": 418
-     }
-   ],
-   "max_steps": 12540,
-   "num_train_epochs": 30,
-   "total_flos": 1749218444181504.0,
-   "trial_name": null,
-   "trial_params": null
- }
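The `eval_loss` logged here, like the per-epoch validation losses in the model card, is a mean cross-entropy, so it converts to perplexity via exp(loss); a quick check of the first and final evaluations:

```python
# Perplexity = exp(cross-entropy loss) for a causal language model.
import math

print(math.exp(2.779296875))  # epoch 1 eval_loss above -> ~16.1
print(math.exp(2.2461))       # final validation loss from the card -> ~9.4
```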
last-checkpoint/training_args.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:83d8354bcc1c4902bf8b8dbff4bd5a46adcf8432ce6e93a791b460f3959caf98
- size 4143
last-checkpoint/zero_to_fp32.py DELETED
@@ -1,453 +0,0 @@
- #!/usr/bin/env python
-
- # This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
- # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
- # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
- # application.
- #
- # example: python zero_to_fp32.py . pytorch_model.bin
-
- import argparse
- import torch
- import glob
- import math
- import os
- from collections import OrderedDict
-
- # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
- # DeepSpeed data structures it has to be available in the current python environment.
- import deepspeed
- from deepspeed.utils import logger
-
- debug = 0
-
- # load to cpu
- device = torch.device('cpu')
-
-
- def get_model_state_file(checkpoint_dir, zero_stage):
-     if not os.path.isdir(checkpoint_dir):
-         raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
-
-     # there should be only one file
-     if zero_stage == 2:
-         file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
-     elif zero_stage == 3:
-         file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
-
-     if not os.path.exists(file):
-         raise FileNotFoundError(f"can't find model states file at '{file}'")
-
-     return file
-
-
- def get_optim_files(checkpoint_dir):
-     # XXX: need to test that this simple glob rule works for multi-node setup too
-     optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")))
-
-     if len(optim_files) == 0:
-         raise FileNotFoundError(
-             f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
-
-     return optim_files
-
-
- def parse_model_state(file):
-     state_dict = torch.load(file, map_location=device)
-
-     if "buffer_names" not in state_dict:
-         raise ValueError(f"{file} is not a model state checkpoint")
-     buffer_names = state_dict["buffer_names"]
-     if debug:
-         print("Found buffers:", buffer_names)
-
-     # recover just the buffers while restoring them to fp32 if they were saved in fp16
-     buffers = {
-         k: v.float()
-         for k,
-         v in state_dict["module"].items() if k in buffer_names
-     }
-     return buffers
-
-
- def parse_optim_states(files, ds_checkpoint_dir):
-
-     total_files = len(files)
-     state_dicts = []
-     for f in files:
-         state_dicts.append(torch.load(f, map_location=device))
-
-     if not "zero_stage" in state_dicts[0]['optimizer_state_dict']:
-         raise ValueError(f"{files[0]} is not a zero checkpoint")
-     zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]
-     world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
-     param_shapes = state_dicts[0]["param_shapes"]
-     # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
-     # parameters can be different from data parallelism for non-expert parameters. So we can just
-     # use the max of the partition_count to get the dp world_size.
-
-     if type(world_size) is list:
-         world_size = max(world_size)
-
-     if world_size != total_files:
-         raise ValueError(
-             f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
-             "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
-         )
-
-     # the groups are named differently in each stage
-     if zero_stage == 2:
-         fp32_groups_key = "single_partition_of_fp32_groups"
-     elif zero_stage == 3:
-         fp32_groups_key = "fp32_flat_groups"
-     else:
-         raise ValueError(f"unknown zero stage {zero_stage}")
-
-     if zero_stage == 2:
-         fp32_flat_groups = [
-             state_dicts[i]['optimizer_state_dict'][fp32_groups_key]
-             for i in range(len(state_dicts))
-         ]
-     elif zero_stage == 3:
-         # if there is more than one param group, there will be multiple flattened tensors - one
-         # flattened tensor per group - for simplicity merge them into a single tensor
-         #
-         # XXX: could make the script more memory efficient for when there are multiple groups - it
-         # will require matching the sub-lists of param_shapes for each param group flattened tensor
-
-         fp32_flat_groups = [
-             torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
-                       0) for i in range(len(state_dicts))
-         ]
-
-     return zero_stage, world_size, param_shapes, fp32_flat_groups
-
-
- def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
-     """
-     Returns fp32 state_dict reconstructed from ds checkpoint
-
-     Args:
-         - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
-
-     """
-     print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
-
-     optim_files = get_optim_files(ds_checkpoint_dir)
-     zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
-     print(
-         f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
-
-     model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
-     buffers = parse_model_state(model_file)
-
-     if zero_stage == 2:
-         return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
-                                                           param_shapes,
-                                                           fp32_flat_groups,
-                                                           buffers)
-     elif zero_stage == 3:
-         return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
-                                                           param_shapes,
-                                                           fp32_flat_groups,
-                                                           buffers)
-
-
- def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
-                                                param_shapes,
-                                                fp32_flat_groups,
-                                                buffers):
-
-     # Reconstruction protocol:
-     #
-     # XXX: document this
-
-     if debug:
-         for i in range(world_size):
-             for j in range(len(fp32_flat_groups[0])):
-                 print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
-
-     # XXX: memory usage doubles here (zero2)
-     num_param_groups = len(fp32_flat_groups[0])
-     merged_single_partition_of_fp32_groups = []
-     for i in range(num_param_groups):
-         merged_partitions = [sd[i] for sd in fp32_flat_groups]
-         full_single_fp32_vector = torch.cat(merged_partitions, 0)
-         merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
-     avail_numel = sum([
-         full_single_fp32_vector.numel()
-         for full_single_fp32_vector in merged_single_partition_of_fp32_groups
-     ])
-
-     if debug:
-         wanted_params = sum([len(shapes) for shapes in param_shapes])
-         wanted_numel = sum(
-             [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
-         # not asserting if there is a mismatch due to possible padding
-         print(f"Have {avail_numel} numels to process.")
-         print(f"Need {wanted_numel} numels in {wanted_params} params.")
-
-     state_dict = OrderedDict()
-
-     # buffers
-     state_dict.update(buffers)
-     if debug:
-         print(f"added {len(buffers)} buffers")
-
-     # params
-     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
-     # out-of-core computing solution
-     total_numel = 0
-     total_params = 0
-     for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
-         offset = 0
-         avail_numel = full_single_fp32_vector.numel()
-         for name, shape in shapes.items():
-
-             unpartitioned_numel = shape.numel()
-             total_numel += unpartitioned_numel
-             total_params += 1
-
-             if debug:
-                 print(
-                     f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
-                 )
-             state_dict[name] = full_single_fp32_vector.narrow(
-                 0,
-                 offset,
-                 unpartitioned_numel).view(shape)
-             offset += unpartitioned_numel
-
-         # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
-         # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
-         # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
-         # live optimizer object, so we are checking that the numbers are within the right range
-         align_to = 2 * world_size
-
-         def zero2_align(x):
-             return align_to * math.ceil(x / align_to)
-
-         if debug:
-             print(f"original offset={offset}, avail_numel={avail_numel}")
-
-         offset = zero2_align(offset)
-         avail_numel = zero2_align(avail_numel)
-
-         if debug:
-             print(f"aligned offset={offset}, avail_numel={avail_numel}")
-
-         # Sanity check
-         if offset != avail_numel:
-             raise ValueError(
-                 f"consumed {offset} numels out of {avail_numel} - something is wrong")
-
-     print(
-         f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
-     )
-
-     return state_dict
-
-
- def zero3_partitioned_param_info(unpartitioned_numel, world_size):
-     remainder = unpartitioned_numel % world_size
-     padding_numel = (world_size - remainder) if remainder else 0
-     partitioned_numel = math.ceil(unpartitioned_numel / world_size)
-     return partitioned_numel, padding_numel
-
-
- def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
-                                                param_shapes,
-                                                fp32_flat_groups,
-                                                buffers):
-
-     # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
-     # param, re-consolidating each param, while dealing with padding if any
-
-     avail_numel = fp32_flat_groups[0].numel() * world_size
-     # merge list of dicts, preserving order
-     param_shapes = {k: v for d in param_shapes for k, v in d.items()}
-
-     if debug:
-         for i in range(world_size):
-             print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}")
-
-     wanted_params = len(param_shapes)
-     wanted_numel = sum(shape.numel() for shape in param_shapes.values())
-     # not asserting if there is a mismatch due to possible padding
-     print(f"Have {avail_numel} numels to process.")
-     print(f"Need {wanted_numel} numels in {wanted_params} params.")
-
-     state_dict = OrderedDict()
-
-     # buffers
-     state_dict.update(buffers)
-     if debug:
-         print(f"added {len(buffers)} buffers")
-
-     # params
-     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
-     # out-of-core computing solution
-     offset = 0
-     total_numel = 0
-     total_params = 0
-     for name, shape in param_shapes.items():
-
-         unpartitioned_numel = shape.numel()
-         total_numel += unpartitioned_numel
-         total_params += 1
-
-         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
-
-         if debug:
-             print(
-                 f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
-             )
-
-         # XXX: memory usage doubles here
-         state_dict[name] = torch.cat(
-             tuple(fp32_flat_groups[i].narrow(0,
-                                              offset,
-                                              partitioned_numel)
-                   for i in range(world_size)),
-             0).narrow(0,
-                       0,
-                       unpartitioned_numel).view(shape)
-         offset += partitioned_numel
-
-     offset *= world_size
-
-     # Sanity check
-     if offset != avail_numel:
-         raise ValueError(
-             f"consumed {offset} numels out of {avail_numel} - something is wrong")
-
-     print(
-         f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
-     )
-
-     return state_dict
-
-
- def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
-     """
-     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
-     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
-     via a model hub.
-
-     Args:
-         - ``checkpoint_dir``: path to the desired checkpoint folder
-         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
-
-     Returns:
-         - pytorch ``state_dict``
-
-     Note: this approach may not work if your application doesn't have sufficient free CPU memory and
-     you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
-     the checkpoint.
-
-     A typical usage might be ::
-
-         from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
-         # do the training and checkpoint saving
-         state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
-         model = model.cpu() # move to cpu
-         model.load_state_dict(state_dict)
-         # submit to model hub or save the model to share with others
-
-     In this example the ``model`` will no longer be usable in the deepspeed context of the same
-     application. i.e. you will need to re-initialize the deepspeed engine, since
-     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
-
-     If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
-
-     """
-     if tag is None:
-         latest_path = os.path.join(checkpoint_dir, 'latest')
-         if os.path.isfile(latest_path):
-             with open(latest_path, 'r') as fd:
-                 tag = fd.read().strip()
-         else:
-             raise ValueError(f"Unable to find 'latest' file at {latest_path}")
-
-     ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
-
-     if not os.path.isdir(ds_checkpoint_dir):
-         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
-
-     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
-
-
- def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
-     """
-     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
-     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
-
-     Args:
-         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
-         - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
-         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
-     """
-
-     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
-     print(f"Saving fp32 state dict to {output_file}")
-     torch.save(state_dict, output_file)
-
-
- def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
-     """
-     1. Put the provided model to cpu
-     2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
-     3. Load it into the provided model
-
-     Args:
-         - ``model``: the model object to update
-         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
-         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
-
-     Returns:
-         - ``model``: modified model
-
-     Make sure you have plenty of CPU memory available before you call this function. If you don't
-     have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
-     conveniently placed for you in the checkpoint folder.
-
-     A typical usage might be ::
-
-         from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
-         model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
-         # submit to model hub or save the model to share with others
-
-     Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
-     of the same application. i.e. you will need to re-initialize the deepspeed engine, since
-     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
-
-     """
-     logger.info(f"Extracting fp32 weights")
-     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
-
-     logger.info(f"Overwriting model with fp32 weights")
-     model = model.cpu()
-     model.load_state_dict(state_dict, strict=False)
-
-     return model
-
-
- if __name__ == "__main__":
-
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "checkpoint_dir",
-         type=str,
-         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
-     parser.add_argument(
-         "output_file",
-         type=str,
-         help=
-         "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
-     )
-     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
-     args = parser.parse_args()
-
-     debug = args.debug
-
-     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:d4158ac6be6eb8a3c9e54b9d4aa678b213d42262a515e563e3f41306e60c4357
+ oid sha256:af7f9e89c13cd0bf18d5984200f49f6a1b168261eddc96ed0771abb13b96ae61
  size 176424894