File size: 24,177 Bytes
4109306 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 |
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py . pytorch_model.bin
import argparse
import torch
import glob
import math
import os
import re
from collections import OrderedDict
from dataclasses import dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
@dataclass
class zero_model_state:
buffers: dict()
param_shapes: dict()
shared_params: list
ds_version: int
frozen_param_shapes: dict()
frozen_param_fragments: dict()
debug = 0
# load to cpu
device = torch.device('cpu')
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return [atoi(c) for c in re.split(r'(\d+)', text)]
def get_model_state_file(checkpoint_dir, zero_stage):
if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
# there should be only one file
if zero_stage <= 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
if not os.path.exists(file):
raise FileNotFoundError(f"can't find model states file at '{file}'")
return file
def get_checkpoint_files(checkpoint_dir, glob_pattern):
# XXX: need to test that this simple glob rule works for multi-node setup too
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
if len(ckpt_files) == 0:
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
return ckpt_files
def get_optim_files(checkpoint_dir):
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
def get_model_state_files(checkpoint_dir):
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
def parse_model_states(files):
zero_model_states = []
for file in files:
state_dict = torch.load(file, map_location=device)
if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
buffer_names = state_dict[BUFFER_NAMES]
if debug:
print("Found buffers:", buffer_names)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
param_shapes = state_dict[PARAM_SHAPES]
# collect parameters that are included in param_shapes
param_names = []
for s in param_shapes:
for name in s.keys():
param_names.append(name)
# update with frozen parameters
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
if frozen_param_shapes is not None:
if debug:
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
param_names += list(frozen_param_shapes.keys())
# handle shared params
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
ds_version = state_dict.get(DS_VERSION, None)
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
z_model_state = zero_model_state(buffers=buffers,
param_shapes=param_shapes,
shared_params=shared_params,
ds_version=ds_version,
frozen_param_shapes=frozen_param_shapes,
frozen_param_fragments=frozen_param_fragments)
zero_model_states.append(z_model_state)
return zero_model_states
def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dict = torch.load(f, map_location=device)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
state_dicts.append(state_dict)
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
raise ValueError(f"{files[0]} is not a zero checkpoint")
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size.
if type(world_size) is list:
world_size = max(world_size)
if world_size != total_files:
raise ValueError(
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
)
# the groups are named differently in each stage
if zero_stage <= 2:
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
elif zero_stage == 3:
fp32_groups_key = FP32_FLAT_GROUPS
else:
raise ValueError(f"unknown zero stage {zero_stage}")
if zero_stage <= 2:
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups = [
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
]
return zero_stage, world_size, fp32_flat_groups
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
"""
Returns fp32 state_dict reconstructed from ds checkpoint
Args:
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
"""
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
model_files = get_model_state_files(ds_checkpoint_dir)
zero_model_states = parse_model_states(model_files)
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
if zero_stage <= 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
def _zero2_merge_frozen_params(state_dict, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
return
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
if debug:
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
state_dict[name] = frozen_param_fragments[name]
if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
# Reconstruction protocol:
#
# XXX: document this
if debug:
for i in range(world_size):
for j in range(len(fp32_flat_groups[0])):
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
# XXX: memory usage doubles here (zero2)
num_param_groups = len(fp32_flat_groups[0])
merged_single_partition_of_fp32_groups = []
for i in range(num_param_groups):
merged_partitions = [sd[i] for sd in fp32_flat_groups]
full_single_fp32_vector = torch.cat(merged_partitions, 0)
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
avail_numel = sum(
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
if debug:
wanted_params = sum([len(shapes) for shapes in param_shapes])
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
# not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.")
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
total_numel = 0
total_params = 0
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
offset = 0
avail_numel = full_single_fp32_vector.numel()
for name, shape in shapes.items():
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
offset += unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
# live optimizer object, so we are checking that the numbers are within the right range
align_to = 2 * world_size
def zero2_align(x):
return align_to * math.ceil(x / align_to)
if debug:
print(f"original offset={offset}, avail_numel={avail_numel}")
offset = zero2_align(offset)
avail_numel = zero2_align(avail_numel)
if debug:
print(f"aligned offset={offset}, avail_numel={avail_numel}")
# Sanity check
if offset != avail_numel:
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
_zero2_merge_frozen_params(state_dict, zero_model_states)
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
remainder = unpartitioned_numel % world_size
padding_numel = (world_size - remainder) if remainder else 0
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
return partitioned_numel, padding_numel
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
return
if debug:
for i in range(world_size):
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in zero_model_states[0].frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
if debug:
print(
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
avail_numel = fp32_flat_groups[0].numel() * world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
# merge list of dicts, preserving order
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
if debug:
for i in range(world_size):
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
wanted_params = len(param_shapes)
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
# not asserting if there is a mismatch due to possible padding
avail_numel = fp32_flat_groups[0].numel() * world_size
print(f"Trainable params: Have {avail_numel} numels to process.")
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
offset = 0
total_numel = 0
total_params = 0
for name, shape in param_shapes.items():
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
if debug:
print(
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
# XXX: memory usage doubles here
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
0).narrow(0, 0, unpartitioned_numel).view(shape)
offset += partitioned_numel
offset *= world_size
# Sanity check
if offset != avail_numel:
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
via a model hub.
Args:
- ``checkpoint_dir``: path to the desired checkpoint folder
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
Returns:
- pytorch ``state_dict``
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
the checkpoint.
A typical usage might be ::
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# do the training and checkpoint saving
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
model = model.cpu() # move to cpu
model.load_state_dict(state_dict)
# submit to model hub or save the model to share with others
In this example the ``model`` will no longer be usable in the deepspeed context of the same
application. i.e. you will need to re-initialize the deepspeed engine, since
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
"""
if tag is None:
latest_path = os.path.join(checkpoint_dir, 'latest')
if os.path.isfile(latest_path):
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
else:
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
if not os.path.isdir(ds_checkpoint_dir):
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
Args:
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
"""
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
print(f"Saving fp32 state dict to {output_file}")
torch.save(state_dict, output_file)
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
"""
1. Put the provided model to cpu
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
3. Load it into the provided model
Args:
- ``model``: the model object to update
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
Returns:
- ``model`: modified model
Make sure you have plenty of CPU memory available before you call this function. If you don't
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
conveniently placed for you in the checkpoint folder.
A typical usage might be ::
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
# submit to model hub or save the model to share with others
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
"""
logger.info(f"Extracting fp32 weights")
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
logger.info(f"Overwriting model with fp32 weights")
model = model.cpu()
model.load_state_dict(state_dict, strict=False)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("checkpoint_dir",
type=str,
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
parser.add_argument(
"output_file",
type=str,
help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
parser.add_argument("-t",
"--tag",
type=str,
default=None,
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
args = parser.parse_args()
debug = args.debug
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
|