# Copyright (c) OpenMMLab. All rights reserved.
import os
from collections import defaultdict
from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.utils import require


@require('accelerate')
def dispatch_model(
    model,
    device_map: Union[str, dict],
    max_memory: Optional[dict] = None,
    no_split_module_classes: Optional[List[str]] = None,
    offload_folder: Optional[str] = None,
    offload_buffers: bool = False,
    preload_module_classes: Optional[List[str]] = None,
):
"""Split and dispatch a model across devices. | |
The function depends on the `accelerate` package. Refers to | |
https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling | |
Args: | |
model (torch.nn.Module): The model to dispatch. | |
device_map (str | dict | None): A map that specifies where each | |
submodule should go. It doesn't need to be refined to each | |
parameter/buffer name, once a given module name is inside, every | |
submodule of it will be sent to the same device. You can use | |
`device_map="auto"` to automatically generate the device map. | |
Defaults to None. | |
max_memory (dict | None): A dictionary device identifier to maximum | |
memory. Will default to the maximum memory available for each GPU | |
and the available CPU RAM if unset. Defaults to None. | |
no_split_module_classes (List[str] | None): A list of layer class names | |
that should never be split across device (for instance any layer | |
that has a residual connection). If None, try to get the settings | |
from the model class. Defaults to None. | |
offload_folder (str | None): If the `device_map` contains any value | |
`"disk"`, the folder where we will offload weights. | |
offload_buffers (bool): In the layers that are offloaded on the CPU | |
or the hard drive, whether or not to offload the buffers as | |
well as the parameters. Defaults to False. | |
preload_module_classes (List[str] | None): A list of classes whose | |
instances should load all their weights (even in the submodules) at | |
the beginning of the forward. This should only be used for classes | |
that have submodules which are registered but not called directly | |
during the forward, for instance if a `dense` linear layer is | |
registered, but at forward, `dense.weight` and `dense.bias` are | |
used in some operations instead of calling `dense` directly. | |
Defaults to None. | |
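
    Examples:
        A minimal usage sketch, assuming the `accelerate` package is
        installed and the model class defines `_no_split_modules`
        (`build_model` and `cfg` below are placeholders, not part of this
        module):

        >>> model = build_model(cfg)  # hypothetical factory
        >>> model = dispatch_model(model, device_map='auto')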
""" | |
    from accelerate import dispatch_model, infer_auto_device_map

    # Check valid device_map string.
    valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential']
    if isinstance(device_map, str) and device_map not in valid_map_option:
        raise ValueError('If passing a string for `device_map`, please choose '
                         f'from {valid_map_option}.')

    # Generate device map automatically
    if isinstance(device_map, str):
        if no_split_module_classes is None:
            no_split_module_classes = getattr(model, '_no_split_modules', None)
        if no_split_module_classes is None:
            raise ValueError(f'{model.__class__.__name__} does not support '
                             f"`device_map='{device_map}'` yet.")

        if device_map != 'sequential':
            from accelerate.utils import get_balanced_memory
            max_memory = get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=no_split_module_classes,
                dtype=None,
                low_zero=(device_map == 'balanced_low_0'),
            )
            # Leave some headroom on the first device during dispatch.
            max_memory[0] *= 0.9

        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=no_split_module_classes,
            dtype=None,
        )

    if 'disk' in device_map.values():
        if offload_folder is None:
            raise ValueError(
                'The current `device_map` had weights offloaded to the disk. '
                'Please provide an `offload_folder` for them.')
        os.makedirs(offload_folder, exist_ok=True)

    main_device = next(
        (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu')

    model = dispatch_model(
        model,
        device_map=device_map,
        main_device=main_device,
        offload_dir=offload_folder,
        offload_buffers=offload_buffers,
        preload_module_classes=preload_module_classes,
    )

    if hasattr(model, 'data_preprocessor'):
        model.data_preprocessor._device = torch.device(main_device)
    return model


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """A context manager under which models are initialized with all
    parameters on the meta device.

    With this context manager, we can create an empty model. Useful when
    just initializing the model would blow the available RAM.

    Besides moving the parameters to the meta device, this context manager
    also skips loading checkpoints from `mmengine.runner.load_checkpoint`
    and `transformers.PreTrainedModel.from_pretrained`.

    Modified from https://github.com/huggingface/accelerate

    Args:
        include_buffers (bool): Whether to put all buffers on the meta
            device during initialization. Defaults to False.
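
    Examples:
        A minimal sketch: modules built inside the context hold only "meta"
        tensors and allocate no real memory.

        >>> import torch.nn as nn
        >>> with init_empty_weights():
        ...     layer = nn.Linear(1024, 1024)
        >>> layer.weight.device
        device(type='meta')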
""" | |
    device = torch.device('meta')

    # move parameter and buffer to meta device
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer
        # See https://github.com/huggingface/accelerate/pull/699
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ['empty', 'zeros', 'ones', 'full']
        }

    def register_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs)

    def register_buffer(module, name, buffer, *args, **kwargs):
        old_register_buffer(module, name, buffer, *args, **kwargs)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):

        def wrapper(*args, **kwargs):
            kwargs['device'] = device
            return fn(*args, **kwargs)

        return wrapper

    # Patch load_checkpoint
    import mmengine.runner.checkpoint as mmengine_load
    old_load_checkpoint = mmengine_load.load_checkpoint

    def patch_load_checkpoint(*args, **kwargs):
        return {}

    # Patch transformers from pretrained
    try:
        from transformers import PreTrainedModel
        from transformers.models.auto.auto_factory import (
            AutoConfig, _BaseAutoModelClass)
        with_transformers = True
    except ImportError:
        with_transformers = False

    def patch_auto_model(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
        cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                         *model_args, **kwargs)
        return cls.from_config(cfg)

    def patch_pretrained_model(cls, pretrained_model_name_or_path,
                               *model_args, **kwargs):
        cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path,
                                               *model_args, **kwargs)
        return cls(cfg)

    if with_transformers:
        old_pretrained_model = PreTrainedModel.from_pretrained
        old_auto_model = _BaseAutoModelClass.from_pretrained

    try:
        nn.Module.register_parameter = register_parameter
        mmengine_load.load_checkpoint = patch_load_checkpoint
        if with_transformers:
            # Wrap with classmethod so that `cls` is bound correctly when the
            # patched methods are called on a (sub)class.
            PreTrainedModel.from_pretrained = classmethod(
                patch_pretrained_model)
            _BaseAutoModelClass.from_pretrained = classmethod(
                patch_auto_model)
        if include_buffers:
            nn.Module.register_buffer = register_buffer
            for func in tensor_constructors_to_patch.keys():
                tensor_constructor = patch_tensor_constructor(
                    getattr(torch, func))
                setattr(torch, func, tensor_constructor)
        yield
    finally:
        # Restore all patched functions on exit.
        nn.Module.register_parameter = old_register_parameter
        mmengine_load.load_checkpoint = old_load_checkpoint
        if with_transformers:
            PreTrainedModel.from_pretrained = old_pretrained_model
            _BaseAutoModelClass.from_pretrained = old_auto_model
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
            for func, ori in tensor_constructors_to_patch.items():
                setattr(torch, func, ori)


def compute_module_sizes(
        model: nn.Module,
        dtype: Union[str, torch.dtype, None] = None,
        special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
    """Compute the size (in bytes) of each submodule of a given model."""

    def get_dtype(dtype):
        # Accept both string names (e.g. 'float16') and torch.dtype objects.
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        if dtype is not None:
            assert isinstance(dtype, torch.dtype)
        return dtype

    def dtype_bytes(dtype: torch.dtype):
        # Number of bytes used to store one element of `dtype`.
        if dtype is torch.bool:
            return 1
        if dtype.is_floating_point:
            return torch.finfo(dtype).bits / 8
        else:
            return torch.iinfo(dtype).bits / 8

    if dtype is not None:
        dtype = get_dtype(dtype)
        dtype_size = dtype_bytes(dtype)

    if special_dtypes is not None:
        special_dtypes = {
            key: dtype_bytes(get_dtype(dtype))
            for key, dtype in special_dtypes.items()
        }

    module_sizes = defaultdict(int)
    for name, tensor in chain(
            model.named_parameters(recurse=True),
            model.named_buffers(recurse=True)):
        if special_dtypes is not None and name in special_dtypes:
            size = tensor.numel() * special_dtypes[name]
        elif dtype is None:
            size = tensor.numel() * tensor.element_size()
        else:
            size = tensor.numel() * min(dtype_size, tensor.element_size())
        # Accumulate the size into every prefix of the tensor name, so that
        # module_sizes[''] holds the total size of the whole model.
        name_parts = name.split('.')
        for idx in range(len(name_parts) + 1):
            module_sizes['.'.join(name_parts[:idx])] += size

    return module_sizes
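

# A minimal usage sketch of `compute_module_sizes` (the tiny model below is
# purely illustrative; with default float32 parameters each element takes
# 4 bytes):
#
#   >>> tiny = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
#   >>> sizes = compute_module_sizes(tiny)
#   >>> sizes['']   # whole model: (4 * 8 + 8) * 4 + (8 * 2 + 2) * 4
#   232
#   >>> sizes['0']  # the first Linear layer only
#   160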