from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_accelerate_version,
    is_gguf_available,
    is_gguf_version,
    is_torch_available,
    logging,
)


if is_torch_available() and is_gguf_available():
    import torch

    from .utils import (
        GGML_QUANT_SIZES,
        GGUFParameter,
        _dequantize_gguf_and_restore_linear,
        _quant_shape_from_byte_shape,
        _replace_with_gguf_linear,
    )


logger = logging.get_logger(__name__)


class GGUFQuantizer(DiffusersQuantizer):
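    """
    Quantizer for loading pre-quantized GGUF checkpoints into diffusers models. Eligible ``nn.Linear`` layers are
    replaced with GGUF-aware linear layers that keep the packed GGML blocks and dequantize them to
    ``compute_dtype`` at runtime.

    This class is normally driven by the model-loading machinery rather than instantiated directly. A minimal,
    illustrative sketch (``GGUFQuantizationConfig`` and ``from_single_file`` are assumed entry points, not
    defined in this module)::

        # hypothetical usage sketch
        transformer = SomeTransformer2DModel.from_single_file(
            "model-Q4_K_S.gguf",
            quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        )
    """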
    use_keep_in_fp32_modules = True

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

        self.compute_dtype = quantization_config.compute_dtype
        self.pre_quantized = quantization_config.pre_quantized
        self.modules_to_not_convert = quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

    def validate_environment(self, *args, **kwargs):
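        """Check that compatible versions of `accelerate` (>= 0.26.0) and `gguf` (>= 0.10.0) are installed."""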
        if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
            raise ImportError(
                "Loading GGUF parameters requires `accelerate` to be installed in your environment: `pip install 'accelerate>=0.26.0'`"
            )
        if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
            raise ImportError(
                "Loading GGUF format files requires `gguf` to be installed in your environment: `pip install gguf>=0.10.0`"
            )

    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        # need more space for buffers that are created during quantization
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
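        """
        Report `torch.uint8` as the storage dtype, since GGUF tensors are loaded as packed byte buffers; this is
        presumably used for device-map/memory planning rather than for compute.
        """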
        if target_dtype != torch.uint8:
            logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
        return torch.uint8

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
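        """Fall back to the configured `compute_dtype` when no explicit `torch_dtype` is provided."""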
        if torch_dtype is None:
            torch_dtype = self.compute_dtype
        return torch_dtype

    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
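        """
        Validate that the quantized tensor from the checkpoint matches the shape the model expects. The logical
        shape is inferred from the packed byte shape using the block size and type size of the GGML quant type.
        """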
        loaded_param_shape = loaded_param.shape
        current_param_shape = current_param.shape
        quant_type = loaded_param.quant_type

        block_size, type_size = GGML_QUANT_SIZES[quant_type]

        inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
        if inferred_shape != current_param_shape:
            raise ValueError(
                f"{param_name} is expected to have shape {current_param_shape}, but the quantized tensor in the "
                f"checkpoint dequantizes to shape {inferred_shape}."
            )

        return True

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: Union["GGUFParameter", "torch.Tensor"],
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
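        """A parameter is treated as quantized if it was loaded as a `GGUFParameter`."""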
        return isinstance(param_value, GGUFParameter)

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: Union["GGUFParameter", "torch.Tensor"],
        param_name: str,
        target_device: "torch.device",
        state_dict: Optional[Dict[str, Any]] = None,
        unexpected_keys: Optional[List[str]] = None,
    ):
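        """
        Assign the loaded `GGUFParameter` (or tensor) to the matching parameter or buffer of the target module,
        moving it to `target_device` without unpacking the quantized data.
        """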
        module, tensor_name = get_module_from_name(model, param_name)
        if tensor_name not in module._parameters and tensor_name not in module._buffers:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

        if tensor_name in module._parameters:
            module._parameters[tensor_name] = param_value.to(target_device)
        if tensor_name in module._buffers:
            module._buffers[tensor_name] = param_value.to(target_device)

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
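        """
        Swap eligible linear layers for their GGUF-aware counterparts before the checkpoint is loaded, skipping
        `modules_to_not_convert` and any modules that must stay in fp32.
        """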
        state_dict = kwargs.get("state_dict", None)

        self.modules_to_not_convert.extend(keep_in_fp32_modules)
        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

        _replace_with_gguf_linear(
            model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
        )

    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
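        """No post-loading processing is needed for GGUF; the model is returned unchanged."""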
        return model

    @property
    def is_serializable(self):
        return False

    @property
    def is_trainable(self) -> bool:
        return False

    def _dequantize(self, model):
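        """
        Restore regular linear layers with dequantized weights. The model is temporarily moved to the current
        CUDA device for dequantization and moved back to CPU afterwards if that is where it started.
        """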
        is_model_on_cpu = model.device.type == "cpu"
        if is_model_on_cpu:
            logger.info(
                "Model was found on CPU (this can happen after `enable_model_cpu_offload()`). Moving it to GPU for "
                "dequantization; it will be moved back to CPU afterwards to preserve the original device."
            )
            model.to(torch.cuda.current_device())

        model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
        if is_model_on_cpu:
            model.to("cpu")
        return model