| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Dict |
|
|
| import torch |
| from compressed_tensors import TRANSFORM_CONFIG_NAME |
| from compressed_tensors.transform import TransformConfig, TransformFactory |
| from compressed_tensors.utils.offload import has_offloaded_params |
|
|
|
|
| __all__ = ["apply_transform_config"] |
|
|
|
|
def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
    """
    Apply a transform config to a model. Weight transforms are fused into weights, while
    activation transforms are attached as submodules and trigger via pytorch hooks

    :param model: model to apply config to
    :param config: transform config to apply
    """
    # Build a factory per config group and let it rewrite the model in place
    for group_name, group_scheme in config.config_groups.items():
        TransformFactory.from_scheme(group_scheme, name=group_name).apply_to_model(model)

    # Record the applied config on the model so it can be serialized alongside weights
    setattr(model, TRANSFORM_CONFIG_NAME, config)

    # Deduplicate offloaded meta tensors so transformers can serialize the model
    # correctly (see _tie_offloaded_tensors for details)
    _tie_offloaded_tensors(model)
|
|
|
|
def _tie_offloaded_tensors(model: torch.nn.Module):
    """
    When accelerate replaces tensors with meta tensors during offloading, the meta
    tensors may not be identical, even if the offloaded values are identical.

    However, transformers can only serialize correctly if meta tensors are identical
    (see transformers#39263).

    This function collects all meta tensors which have shared offloaded values and sets
    those tensors to be identical so that they can be removed during serialization

    :param model: model potentially containing offloaded meta tensors to fix
    """
    # Map each offloaded storage pointer to a single canonical meta parameter;
    # every module whose offloaded value shares that pointer gets the same object.
    canonical_by_ptr: Dict[int, torch.nn.Parameter] = {}
    for submodule in model.modules():
        if not has_offloaded_params(submodule):
            continue
        for param_name, _ in submodule.named_parameters(recurse=False):
            ptr = submodule._hf_hook.weights_map[param_name].data_ptr()
            # First module seen for this pointer donates its meta tensor;
            # setdefault returns the existing canonical tensor otherwise.
            canonical = canonical_by_ptr.setdefault(ptr, getattr(submodule, param_name))
            setattr(submodule, param_name, canonical)
|
|