Gunulhona committed
Commit 20ed20e
1 Parent(s): 64c1d54

Upload folder using huggingface_hub

checkpoint/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44c8c07c13cb26ef988277d0c4930151d0bfbb5299f898d82b5dc1d05d5eb921
+size 6720198532
checkpoint/zero_pp_rank_0_mp_rank_00_optim_states.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e13d3dbd2caf27a59f05d11ea06ae24dbd0ba7f8c74db9e74d2db7431297ab6b
-size 27606881822
+oid sha256:5f928d931609ef10ba41e20d9e37cdd8e737c1debad9dbe9d1455bdb1cbdbad9
+size 40317564505
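
Note: both .pt entries above are Git LFS pointer files, not the tensors themselves: the repository tracks only the version/oid/size triplet, while the multi-gigabyte payloads live in LFS storage. A minimal sketch (not part of this commit) of reading such a pointer, assuming the file was checked out without the LFS smudge filter so the pointer text is still on disk; read_lfs_pointer is a hypothetical helper:

def read_lfs_pointer(path):
    """Parse the key/value lines of a Git LFS pointer file into a dict."""
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# e.g. {'version': 'https://git-lfs.github.com/spec/v1',
#       'oid': 'sha256:44c8c07c...', 'size': '6720198532'}
print(read_lfs_pointer("checkpoint/mp_rank_00_model_states.pt"))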
zero_to_fp32.py CHANGED
@@ -5,7 +5,7 @@
 
 # DeepSpeed Team
 
-# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
+# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
 # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
 # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
 # application.
@@ -63,7 +63,7 @@ def get_model_state_file(checkpoint_dir, zero_stage):
         raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
 
     # there should be only one file
-    if zero_stage == 2:
+    if zero_stage <= 2:
         file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
     elif zero_stage == 3:
         file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
@@ -143,7 +143,11 @@ def parse_optim_states(files, ds_checkpoint_dir):
     total_files = len(files)
     state_dicts = []
     for f in files:
-        state_dicts.append(torch.load(f, map_location=device))
+        state_dict = torch.load(f, map_location=device)
+        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
+        # and also handle the case where it was already removed by another helper script
+        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+        state_dicts.append(state_dict)
 
     if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
         raise ValueError(f"{files[0]} is not a zero checkpoint")
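
Note: the reworked loading loop above is the memory-relevant change in this diff: immediately after torch.load, the nested optimizer state (the Adam moments that dominate the ~40 GB optim_states.pt file in this checkpoint) is dropped, keeping only the fp32 master weights. A toy illustration of why pop(..., None) is used; the nesting mirrors the keys referenced above, but the exact surrounding key names are an assumption:

# The key "optimizer_state_dict" appears at the top level and again inside it.
state_dict = {
    "optimizer_state_dict": {
        "optimizer_state_dict": {"exp_avg": "huge Adam moments"},     # discarded
        "single_partition_of_fp32_groups": ["fp32 master weights"],  # kept
    }
}
# pop() with a default is a no-op if another helper script already removed the key
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
assert "single_partition_of_fp32_groups" in state_dict["optimizer_state_dict"]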
@@ -164,14 +168,14 @@ def parse_optim_states(files, ds_checkpoint_dir):
         )
 
     # the groups are named differently in each stage
-    if zero_stage == 2:
+    if zero_stage <= 2:
         fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
     elif zero_stage == 3:
         fp32_groups_key = FP32_FLAT_GROUPS
     else:
         raise ValueError(f"unknown zero stage {zero_stage}")
 
-    if zero_stage == 2:
+    if zero_stage <= 2:
         fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
     elif zero_stage == 3:
         # if there is more than one param group, there will be multiple flattened tensors - one
@@ -206,7 +210,7 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
     zero_model_states = parse_model_states(model_files)
     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
 
-    if zero_stage == 2:
+    if zero_stage <= 2:
         return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
     elif zero_stage == 3:
         return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
@@ -244,6 +248,11 @@ def _zero2_merge_frozen_params(state_dict, zero_model_states):
     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
 
 
+def _has_callable(obj, fn):
+    attr = getattr(obj, fn, None)
+    return callable(attr)
+
+
 def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     param_shapes = zero_model_states[0].param_shapes
 
@@ -283,7 +292,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
         avail_numel = full_single_fp32_vector.numel()
         for name, shape in shapes.items():
 
-            unpartitioned_numel = shape.numel()
+            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
             total_numel += unpartitioned_numel
             total_params += 1
 
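
Note: the two hunks above work together: _has_callable lets the shape lookup accept both a torch.Size (which exposes .numel()) and a shape stored as a plain list or tuple of ints (which does not), falling back to math.prod; this assumes math is imported at the top of the script. A self-contained check of the pattern:

import math
import torch

def _has_callable(obj, fn):
    attr = getattr(obj, fn, None)
    return callable(attr)

# torch.Size carries .numel(); a shape serialized as a plain list does not
for shape in (torch.Size([4, 3]), [4, 3]):
    numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
    assert numel == 12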
@@ -570,9 +579,14 @@ if __name__ == "__main__":
         "output_file",
         type=str,
         help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+    parser.add_argument("-t",
+                        "--tag",
+                        type=str,
+                        default=None,
+                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
     args = parser.parse_args()
 
     debug = args.debug
 
-    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
 
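Note: with --tag wired through to convert_zero_checkpoint_to_fp32_state_dict, the conversion can target a specific checkpoint folder instead of relying on a 'latest' file. A usage sketch against this repository's layout; that the tag is the 'checkpoint' folder and the repo root is the checkpoint dir are assumptions:

# Equivalent CLI: python zero_to_fp32.py . pytorch_model.bin --tag checkpoint
from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    ".",                  # directory containing the tag subfolder
    "pytorch_model.bin",  # consolidated fp32 output file
    tag="checkpoint",     # new in this commit; falls back to the 'latest' file when None
)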