import logging
import re
from functools import cache, partial
from pathlib import Path
from typing import Callable, TypeVar

import deepspeed
import pandas as pd
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.utils import clip_grad_norm_
from torch import nn

from .distributed import fix_unset_envs

logger = logging.getLogger(__name__)

T = TypeVar("T")


def flatten_dict(d: dict) -> dict:
    """Flatten a nested dict into a single level, joining keys with "/"."""
    records = pd.json_normalize(d, sep="/").to_dict(orient="records")
    return records[0] if records else {}
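

# Illustrative sketch (not part of the original module): pd.json_normalize joins
# nested keys with the separator, so e.g.
#   flatten_dict({"model": {"loss": 1.0, "acc": 0.9}})
# returns {"model/loss": 1.0, "model/acc": 0.9}.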


def _get_named_modules(module: nn.Module, attrname: str, sep: str = "/"):
    """Yield (name, submodule) for every submodule that has attrname set."""
    for name, submodule in module.named_modules():
        name = name.replace(".", sep)
        if hasattr(submodule, attrname):
            yield name, submodule
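

# For example, a submodule registered under the (hypothetical) path
# "decoder.blocks.0" that carries the attribute is yielded as "decoder/blocks/0".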


def gather_attribute(module, attrname, delete=True, prefix=None):
    """Collect attrname from every submodule into a flat {"path": value} dict.

    By default the attribute is deleted from each submodule after being
    gathered, so values are consumed exactly once.
    """
    ret = {}
    for name, submodule in _get_named_modules(module, attrname):
        ret[name] = getattr(submodule, attrname)
        if delete:
            try:
                delattr(submodule, attrname)
            except Exception as e:
                raise RuntimeError(f"Failed to delete {attrname!r} from {name} ({submodule})") from e
    if prefix:
        ret = {prefix: ret}
    ret = flatten_dict(ret)

    # The root module has an empty name, so collapse any duplicate separators.
    ret = {re.sub(r"/+", "/", k): v for k, v in ret.items()}
    return ret
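

# Intended pattern (attribute name hypothetical): submodules stash per-module
# values on themselves during forward, and the training loop gathers them once
# per step for logging:
#   self.stats = {"attn_entropy": h}          # inside some submodule
#   logs = gather_attribute(model, "stats")   # -> {"decoder/blocks/0/attn_entropy": ...}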


def dispatch_attribute(
    module: nn.Module,
    attrname: str,
    value,
    filter_fn: Callable[[nn.Module], bool] | None = None,
):
    """Set attrname to value on every submodule that already defines it."""
    for _, submodule in _get_named_modules(module, attrname):
        if filter_fn is None or filter_fn(submodule):
            setattr(submodule, attrname, value)
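

# Counterpart sketch (attribute and filter hypothetical): broadcast a value to
# every submodule that already declares the attribute, optionally filtered:
#   dispatch_attribute(model, "sampling_temperature", 1.2,
#                      filter_fn=lambda m: isinstance(m, nn.MultiheadAttention))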


@cache
def update_deepspeed_logger():
    """Silence DeepSpeed's verbose INFO output; @cache makes this run once."""
    logging.getLogger("DeepSpeed").setLevel(logging.WARNING)


@cache
def init_distributed():
    """Initialize the DeepSpeed distributed backend exactly once per process."""
    update_deepspeed_logger()
    fix_unset_envs()
    deepspeed.init_distributed(get_accelerator().communication_backend_name())


def _try_each(*fns: Callable[[], T]) -> T:
    """Call each zero-argument callable in turn, returning the first result
    that does not raise; raise RuntimeError if every callable fails."""
    if len(fns) == 0:
        raise RuntimeError("All functions failed")

    head, *tails = fns

    try:
        return head()
    except Exception as e:
        logger.warning(f"Tried {head} but failed: {e}, trying next")
        return _try_each(*tails)
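

# Usage sketch (callables hypothetical): try the strictest option first and
# fall back in order; the first one that does not raise wins:
#   state = _try_each(
#       lambda: load(strict=True),
#       lambda: load(strict=False),
#   )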


class Engine(DeepSpeedEngine):
    def __init__(self, *args, ckpt_dir: Path, **kwargs):
        init_distributed()
        # `args` here is the argparse namespace DeepSpeedEngine accepts; this
        # wrapper always passes None and relies on kwargs/config instead.
        super().__init__(*args, args=None, **kwargs)
        self._ckpt_dir = ckpt_dir
        self._frozen_params = set()
        self._fp32_grad_norm = None

    @property
    def path(self):
        return self._ckpt_dir

    def freeze_(self):
        """Disable gradients for all trainable parameters, remembering which
        ones were frozen so unfreeze_() restores exactly that set."""
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)

    def unfreeze_(self):
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def global_step(self):
        return self.global_steps

    def gather_attribute(self, *args, **kwargs):
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        return dispatch_attribute(self.module, *args, **kwargs)

    def clip_fp32_gradients(self):
        # Fallback for configurations where DeepSpeed does not clip or report
        # the gradient norm itself (e.g. plain fp32): clip manually and record
        # the norm for get_grad_norm().
        self._fp32_grad_norm = clip_grad_norm_(
            parameters=self.module.parameters(),
            max_norm=self.gradient_clipping(),
            mpu=self.mpu,
        )

    def get_grad_norm(self):
        # Prefer the norm DeepSpeed tracked during the step; fall back to the
        # norm recorded by clip_fp32_gradients().
        grad_norm = self.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = self._fp32_grad_norm
        return grad_norm

    def save_checkpoint(self, *args, **kwargs):
        self._ckpt_dir.mkdir(parents=True, exist_ok=True)
        super().save_checkpoint(*args, save_dir=self._ckpt_dir, **kwargs)
        logger.info(f"Saved checkpoint to {self._ckpt_dir}")

    def load_checkpoint(self, *args, **kwargs):
        fn = partial(super().load_checkpoint, *args, load_dir=self._ckpt_dir, **kwargs)
        # Degrade gracefully: retry with progressively laxer loading options
        # until one combination succeeds.
        return _try_each(
            fn,
            lambda: fn(load_optimizer_states=False),
            lambda: fn(load_lr_scheduler_states=False),
            lambda: fn(load_optimizer_states=False, load_lr_scheduler_states=False),
            lambda: fn(
                load_optimizer_states=False,
                load_lr_scheduler_states=False,
                load_module_strict=False,
            ),
        )
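

# End-to-end sketch (names hypothetical; an Engine is normally constructed by
# the surrounding training harness with a full DeepSpeed config):
#   engine = Engine(model=model, model_parameters=model.parameters(),
#                   config=ds_config, ckpt_dir=Path("ckpts/run1"))
#   engine.load_checkpoint()            # walks the fallback ladder above
#   engine.save_checkpoint(tag="latest")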