import torch
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sized, Tuple, Union

from pytorch_lightning import strategies
from lightning_fabric.utilities.types import _PATH
from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
'''
Overwrite DeepSpeedEngine.module_state_dict so that frozen (non-trainable)
parameters can be excluded from the saved state dict.
'''
### start overwrite ###
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
    sd = self.module.state_dict(destination, prefix, keep_vars)
    # Remove frozen parameter weights from the state_dict if requested.
    if exclude_frozen_parameters:
        to_be_removed = []
        for n in sd:
            try:
                if not self.module.get_parameter(n).requires_grad:
                    to_be_removed.append(n)
            except AttributeError:
                # Names that do not resolve to a parameter (e.g. buffers)
                # raise AttributeError and are dropped as well.
                to_be_removed.append(n)
        for key in to_be_removed:
            sd.pop(key)
    if self.random_ltd_enabled():
        sd = remove_random_ltd_state_dict(sd)
    return sd

from deepspeed import DeepSpeedEngine

DeepSpeedEngine.module_state_dict = module_state_dict
### end overwrite ###
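
# A minimal sketch of the patched behavior (illustrative only; ``model`` and
# ``ds_config`` are hypothetical names, not defined in this file):
#
#   engine, *_ = deepspeed.initialize(model=model, config=ds_config)
#   sd = engine.module_state_dict(exclude_frozen_parameters=True)
#   # `sd` now keeps only entries whose parameters have requires_grad=True.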


class MyDeepSpeedStrategy(strategies.DeepSpeedStrategy):

    def save_checkpoint_v1(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ):
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to the ``CheckpointIO`` plugin
        """
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    def load_model_state_dict(self, checkpoint):
        assert self.lightning_module is not None
        # ``strict=False`` because checkpoints written by ``save_checkpoint``
        # below exclude frozen parameters, so some keys may be missing.
        self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=False)

    def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: the checkpoint state dictionary
            filepath: write-target file's path
            storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        # Broadcast the filepath from rank 0 to ensure all states are saved to a common filepath.
        filepath = self.broadcast(filepath)
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
            )
        if self.zero_stage_3 and self._multi_device and self.is_global_zero:
            print(
                "Warning: When saving the DeepSpeed Stage 3 checkpoint, "
                "each worker will save a shard of the checkpoint within a directory. "
                "If a single file is required after training, "
                "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
                "deepspeed-zero-stage-3-single-file for instructions."
            )
        # Use DeepSpeed's internal checkpointing function to handle partitioned
        # weights across processes: drop the keys DeepSpeed dumps itself and pass
        # the remaining states as client state. ``exclude_frozen_parameters=True``
        # relies on the ``module_state_dict`` override above.
        _exclude_keys = ["state_dict", "optimizer_states"]
        checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
        self.deepspeed_engine.save_checkpoint(
            filepath, client_state=checkpoint, tag="checkpoint", exclude_frozen_parameters=True
        )
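
# Illustrative usage sketch (assumption: ``MyLightningModule`` is a user-defined
# LightningModule, not part of this file):
#
#   import pytorch_lightning as pl
#
#   trainer = pl.Trainer(
#       accelerator="gpu",
#       devices=2,
#       strategy=MyDeepSpeedStrategy(stage=3),
#       precision="16-mixed",
#   )
#   trainer.fit(MyLightningModule())
#
# Checkpoints written via ``trainer.save_checkpoint(path)`` then go through
# ``save_checkpoint`` above and omit frozen parameters.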