#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Robust ZeRO->fp32 converter for Torch>=2.6 (weights_only=True default).
It (1) pre-allowlists common DeepSpeed symbols; (2) on failure, parses the
'Unsupported global: GLOBAL ...' from the exception, allowlists it, and retries.
Also provides ConvertAfterSaveCallback for use in stage1.py / stage2.py to
run conversion automatically after each checkpoint save when using DeepSpeed.
"""
import argparse
import os
import re
import importlib
from pathlib import Path

def _has_add_safe_globals():
    try:
        from torch.serialization import add_safe_globals  # noqa: F401
        return True
    except Exception:
        return False


def _add_safe(objs):
    try:
        from torch.serialization import add_safe_globals
        add_safe_globals(objs)
    except Exception:
        pass
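
# Note: recent torch releases also expose torch.serialization.safe_globals(...)
# as a scoped context manager; this module uses the process-global
# add_safe_globals so the allowlist persists across conversion retries.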

def _try_import_symbol(qualname: str):
    """
    Import 'a.b.c' -> return object 'c' from module 'a.b'.
    Returns None if anything fails.
    """
    try:
        mod_name, attr = qualname.rsplit('.', 1)
        mod = importlib.import_module(mod_name)
        return getattr(mod, attr)
    except Exception:
        return None
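
# For example, _try_import_symbol('deepspeed.runtime.fp16.loss_scaler.LossScaler')
# returns the LossScaler class when deepspeed is importable, and None otherwise.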

def _pre_allowlist_commons():
    # Pre-allowlist common DeepSpeed symbols seen in ZeRO shards
    commons = [
        # FP16 loss scalers
        "deepspeed.runtime.fp16.loss_scaler.LossScaler",
        "deepspeed.runtime.fp16.dynamic_loss_scaler.DynamicLossScaler",
        # ZeRO enums/config/status
        "deepspeed.runtime.zero.config.ZeroStageEnum",
        "deepspeed.runtime.zero.stage_1_and_2.ZeroParamStatus",
        "deepspeed.runtime.zero.stage_1_and_2.ZeroOptimizerStage2",
        "deepspeed.runtime.config.DeepSpeedConfig",
        # Tensor-fragment metadata (frequently hit when unpickling ZeRO shards)
        "deepspeed.utils.tensor_fragment.fragment_address",
    ]
    objs = []
    for qn in commons:
        obj = _try_import_symbol(qn)
        if obj is not None:
            objs.append(obj)
    if objs:
        _add_safe(objs)

def _extract_unsupported_globals(msg: str):
    """
    Parse error text for lines like:
        'Unsupported global: GLOBAL deepspeed.utils.tensor_fragment.fragment_address'
    Return a list of qualified names.
    """
    pats = [
        r"Unsupported global:\s+GLOBAL\s+([A-Za-z0-9_\.]+)",
        r"was not an allowed global.*?\[\s*([A-Za-z0-9_\.]+)\s*\]",
    ]
    found = set()
    for pat in pats:
        for m in re.finditer(pat, msg):
            found.add(m.group(1))
    return list(found)
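
# Illustrative round trip (error text abridged):
#   _extract_unsupported_globals(
#       "... Unsupported global: GLOBAL deepspeed.utils.tensor_fragment.fragment_address ...")
#   -> ['deepspeed.utils.tensor_fragment.fragment_address']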

def convert_zero_to_fp32(ckpt_dir: str, out_path: str, max_retries: int = 5):
    from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

    # Step 0: pre-allowlist common DeepSpeed symbols (no-op on old torch)
    if _has_add_safe_globals():
        _pre_allowlist_commons()

    # Step 1: try the conversion; on failure, parse & allowlist the missing
    # globals reported in the exception, then retry
    last_err = None
    for attempt in range(1, max_retries + 1):
        try:
            convert_zero_checkpoint_to_fp32_state_dict(ckpt_dir, out_path)
            print(f"[OK] Converted ZeRO checkpoint → {out_path}")
            return
        except Exception as e:
            last_err = e
            msg = str(e)
            missing = _extract_unsupported_globals(msg) if _has_add_safe_globals() else []
            if not missing:
                # Nothing to auto-allowlist (or old torch) -> bail out
                break
            objs = []
            for qn in missing:
                obj = _try_import_symbol(qn)
                if obj is not None:
                    objs.append(obj)
            if objs:
                _add_safe(objs)
                print(f"[Retry {attempt}/{max_retries}] allowlisted: {', '.join(missing)}; retrying…")
                continue
            # Couldn't import any of the missing symbols
            break

    # If we reach here, conversion failed
    if last_err is None:
        raise RuntimeError(f"Conversion never ran (max_retries={max_retries})")
    raise last_err
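
# Programmatic usage (paths are illustrative):
#   convert_zero_to_fp32('logs/epoch=12.ckpt', 'logs/epoch=12_fp32.ckpt')
# writes a consolidated fp32 checkpoint that plain torch.load() can read.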

def _convert_after_save_callback_class(run_after_train_epoch):
    """Build a PyTorch Lightning Callback class that runs conversion after a checkpoint save (DeepSpeed only, rank 0)."""
    import pytorch_lightning as pl

    class _ConvertAfterSaveCallback(pl.Callback):
        def __init__(self, dirpath, save_every_n_epochs):
            self.dirpath = dirpath.rstrip(os.sep)
            self.save_every_n_epochs = save_every_n_epochs
            self._run_after_train = run_after_train_epoch

        def _maybe_convert(self, trainer):
            # Only rank 0 converts, and only under a DeepSpeed strategy
            if getattr(trainer, 'global_rank', 0) != 0:
                return
            strategy = getattr(trainer, 'strategy', None)
            if strategy is None or 'DeepSpeed' not in type(strategy).__name__:
                return
            epoch = trainer.current_epoch + 1
            if epoch % self.save_every_n_epochs != 0:
                return
            for cb in trainer.callbacks:
                if type(cb).__name__ == 'ModelCheckpoint':
                    last_path = getattr(cb, 'last_model_path', None) or getattr(cb, 'best_model_path', None)
                    if not last_path or not os.path.exists(last_path):
                        return
                    out_path = os.path.join(self.dirpath, 'converted.ckpt')
                    try:
                        convert_zero_to_fp32(last_path, out_path)
                    except Exception as e:
                        print(f"[ConvertAfterSave] Conversion failed: {e}")
                    return

        def on_train_epoch_end(self, trainer, pl_module):
            if self._run_after_train:
                self._maybe_convert(trainer)

        def on_validation_epoch_end(self, trainer, pl_module):
            if not self._run_after_train:
                self._maybe_convert(trainer)

    return _ConvertAfterSaveCallback

def ConvertAfterSaveCallback(dirpath, save_every_n_epochs, run_after_train_epoch=True):
    """Callback factory: after each checkpoint save, run ZeRO->fp32 conversion and write dirpath/converted.ckpt."""
    return _convert_after_save_callback_class(run_after_train_epoch)(dirpath, save_every_n_epochs)
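
# Illustrative wiring in a training script (the model and dirpath are
# placeholders, not part of this module):
#   import pytorch_lightning as pl
#   ckpt_cb = pl.callbacks.ModelCheckpoint(dirpath='checkpoints/', every_n_epochs=1, save_top_k=-1)
#   convert_cb = ConvertAfterSaveCallback('checkpoints/', save_every_n_epochs=1)
#   trainer = pl.Trainer(strategy='deepspeed_stage_2', callbacks=[ckpt_cb, convert_cb])
#   trainer.fit(model)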

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True,
                        help='Path to the ZeRO checkpoint folder (…/epoch=XX.ckpt/checkpoint)')
    parser.add_argument('--output', type=str, default=None,
                        help='Path to the output fp32 PyTorch state_dict file')
    args = parser.parse_args()
    ckpt_dir = Path(args.input)
    out = Path(args.output) if args.output is not None else (ckpt_dir / 'converted.ckpt')
    convert_zero_to_fp32(str(ckpt_dir), str(out))


if __name__ == '__main__':
    main()
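
# Example invocation (script name and paths are illustrative):
#   python convert_zero_robust.py --input logs/epoch=12.ckpt --output logs/epoch=12_fp32.ckpt
# With --output omitted, the result is written to <input>/converted.ckpt.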

# --- Legacy converter kept for reference (pre-Torch 2.6; no allowlisting) ---
# import argparse
# from pathlib import Path
# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
#
# if __name__ == '__main__':
#     # Read a path using argparse and pass it to convert_zero_checkpoint_to_fp32_state_dict
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--input', type=str, default=None, help='path to the desired checkpoint folder')
#     parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
#     # parser.add_argument('--tag', type=str, help='checkpoint tag used as a unique identifier for checkpoint')
#     args = parser.parse_args()
#     if args.output is None:
#         args.output = Path(args.input) / 'converted.ckpt'
#     convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)