# coding=utf-8
# Copyright 2023 HuggingFace Inc. team and GPTQ and AutoGPTQ authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from enum import Enum
from logging import getLogger
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers.pytorch_utils import Conv1D
from transformers.utils.quantization_config import QuantizationMethod

from ..utils import is_accelerate_available, is_auto_gptq_available
from ..utils.modeling_utils import recurse_getattr
from .constants import GPTQ_CONFIG
from .data import get_dataset, prepare_dataset
from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen


if is_accelerate_available():
    from accelerate import (
        cpu_offload_with_hook,
        load_checkpoint_and_dispatch,
    )
    from accelerate.hooks import remove_hook_from_module

if is_auto_gptq_available():
    from auto_gptq import exllama_set_max_input_length
    from auto_gptq.modeling._utils import autogptq_post_init
    from auto_gptq.quantization import GPTQ
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

logger = getLogger(__name__)


class ExllamaVersion(int, Enum):
    ONE = 1
    TWO = 2


class GPTQQuantizer(object):
    r"""
    A simple API for GPTQ Quantization
    """

    def __init__(
        self,
        bits: int,
        dataset: Optional[Union[List[str], str]] = None,
        group_size: int = 128,
        damp_percent: float = 0.1,
        desc_act: bool = False,
        sym: bool = True,
        true_sequential: bool = True,
        use_cuda_fp16: bool = False,
        model_seqlen: Optional[int] = None,
        block_name_to_quantize: Optional[str] = None,
        module_name_preceding_first_block: Optional[List[str]] = None,
        batch_size: int = 1,
        pad_token_id: Optional[int] = None,
        disable_exllama: bool = False,
        exllama_config: Dict[str, Any] = None,
        max_input_length: Optional[int] = None,
        cache_block_outputs: Optional[bool] = True,
        modules_in_block_to_quantize: Optional[List[List[str]]] = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            bits (`int`):
                The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
            dataset (`Union[List[str], str, Any]`, defaults to `None`):
                The dataset used for quantization. You can provide your own dataset in a list of strings or in a list of tokenized data
                (e.g. [{ "input_ids": [ 1, 100, 15, ... ], "attention_mask": [ 1, 1, 1, ... ]}, ...])
                or just use the original datasets used in the GPTQ paper ['wikitext2', 'c4', 'c4-new', 'ptb', 'ptb-new'].
            group_size (`int`, defaults to `128`):
                The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
            damp_percent (`float`, defaults to `0.1`):
                The percent of the average Hessian diagonal to use for dampening, recommended value is 0.1.
            desc_act (`bool`, defaults to `False`):
                Whether to quantize columns in order of decreasing activation size.
                Setting it to False can significantly speed up inference but the perplexity may become slightly worse.
                Also known as act-order.
            sym (`bool`, defaults to `True`):
                Whether to use symmetric quantization.
            true_sequential (`bool`, defaults to `True`):
                Whether to perform sequential quantization even within a single Transformer block.
                Instead of quantizing the entire block at once, we perform layer-wise quantization.
                As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers.
            use_cuda_fp16 (`bool`, defaults to `False`):
                Whether or not to use the optimized cuda kernel for fp16 models. The model needs to be in fp16.
            model_seqlen (`Optional[int]`, defaults to `None`):
                The maximum sequence length that the model can take.
            block_name_to_quantize (`Optional[str]`, defaults to `None`):
                The transformers block name to quantize. If None, we will infer the block name using common patterns (e.g. model.layers).
            module_name_preceding_first_block (`Optional[List[str]]`, defaults to `None`):
                The layers that precede the first Transformer block.
            batch_size (`int`, defaults to `1`):
                The batch size of the dataset.
            pad_token_id (`Optional[int]`, defaults to `None`):
                The pad token id. Needed to prepare the dataset when `batch_size` > 1.
            disable_exllama (`bool`, defaults to `False`):
                Whether to disable the exllama backend. The exllama kernels are only used when `bits` = 4.
            exllama_config (`Dict[str, Any]`, *optional*):
                The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults to `{"version": 2}` if unset.
            max_input_length (`Optional[int]`, defaults to `None`):
                The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
                It is specific to the exllama backend with act-order.
            cache_block_outputs (`bool`, defaults to `True`):
                Whether to cache block outputs to reuse as inputs for the succeeding block. It allows optimization of non-standard models
                (e.g. ChatGLM) but can require more time.
            modules_in_block_to_quantize (`Optional[List[List[str]]]`, defaults to `None`):
                List of lists of module names to quantize in the specified block. This argument is useful to exclude certain linear modules
                from being quantized. The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially.
                If not set, we will quantize all linear layers.
                Example: `modules_in_block_to_quantize = [["self_attention.query_key_value"], ["mlp.dense_h_to_4h"]]`
        """

        self.bits = bits
        self.dataset = dataset
        self.group_size = group_size
        self.damp_percent = damp_percent
        self.desc_act = desc_act
        self.sym = sym
        self.true_sequential = true_sequential
        self.use_cuda_fp16 = use_cuda_fp16
        self.model_seqlen = model_seqlen
        self.block_name_to_quantize = block_name_to_quantize
        self.module_name_preceding_first_block = module_name_preceding_first_block
        self.batch_size = batch_size
        self.pad_token_id = pad_token_id
        self.disable_exllama = disable_exllama
        self.exllama_config = exllama_config
        self.max_input_length = max_input_length
        self.quant_method = QuantizationMethod.GPTQ
        self.cache_block_outputs = cache_block_outputs
        self.modules_in_block_to_quantize = modules_in_block_to_quantize
        self.serialization_keys = [
            "bits",
            "dataset",
            "group_size",
            "damp_percent",
            "desc_act",
            "sym",
            "true_sequential",
            "quant_method",
            "modules_in_block_to_quantize",
        ]

        if self.bits not in [2, 3, 4, 8]:
            raise ValueError("Only quantization to 2, 3, 4 and 8 bits is supported.")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("group_size must be greater than 0 or equal to -1")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")

        if self.exllama_config is None:
            self.exllama_config = {"version": ExllamaVersion.TWO}
        else:
            if "version" not in self.exllama_config:
                raise ValueError("`exllama_config` needs to have a `version` key")
            elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
                version = self.exllama_config["version"]
                raise ValueError(
                    f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {version}"
                )
        self.exllama_version = self.exllama_config["version"]

    def to_dict(self):
        """
        Returns the args in dict format.
        """
        gptq_dict = {}
        for key in self.serialization_keys:
            gptq_dict[key] = getattr(self, key)
        return gptq_dict

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]):
        """
        Instantiates a `GPTQQuantizer` using config_dict as kwargs

        Args:
            config_dict (`Dict[str,Any]`): quantization config

        Returns:
            `GPTQQuantizer`: The quantizer object instantiated from those parameters.
        """
        return cls(**config_dict)
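
    # Example (illustrative sketch, values are placeholders): a quantizer can be rebuilt from a
    # previously serialized config, e.g. the dict written by `save()` into the GPTQ_CONFIG file:
    #     quantizer = GPTQQuantizer.from_dict({"bits": 4, "group_size": 128, "desc_act": False})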

    def convert_model(self, model: nn.Module):
        """
        Convert the model to a GPTQ model by getting and replacing the layers.

        Args:
            model (`nn.Module`):
                Model to be converted
        """
        if self.block_name_to_quantize is None:
            self.block_name_to_quantize = get_block_name_with_pattern(model)
        block_name = self.block_name_to_quantize
        layers_to_be_replaced = get_layers(model, prefix=block_name)
        if self.modules_in_block_to_quantize is not None:
            layers_to_keep = sum(self.modules_in_block_to_quantize, [])
            for name in list(layers_to_be_replaced.keys()):
                if not any(name.endswith(layer) for layer in layers_to_keep):
                    logger.info(
                        f"Quantization disabled for {name} (only modules_in_block_to_quantize={self.modules_in_block_to_quantize} are quantized)"
                    )
                    del layers_to_be_replaced[name]
        self._replace_by_quant_layers(model, layers_to_be_replaced)
        return model

    def get_no_split_module_classes(self, model):
        """
        Get the modules that should not be split across multiple devices.

        Args:
            model (`nn.Module`):
                The input model
        """
        block_class_name = recurse_getattr(model, self.block_name_to_quantize)[0].__class__.__name__
        no_split_module_classes = [block_class_name]
        return no_split_module_classes

    def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""):
        """
        Replaces linear layers in `module` by `QuantLinear`

        Args:
            module (`nn.Module`):
                Module to quantize
            names (`List[str]`):
                List of names of the modules to quantize
            name (`str`, defaults to `""`):
                To keep track of the name of the current module
        """
        QuantLinear = dynamically_import_QuantLinear(
            use_triton=False,
            desc_act=self.desc_act,
            group_size=self.group_size,
            bits=self.bits,
            disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
            disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
        )
        if isinstance(module, QuantLinear):
            return
        for attr in dir(module):
            layer = getattr(module, attr)
            name1 = name + "." + attr if name != "" else attr
            if name1 in names:
                device = get_device(layer)
                delattr(module, attr)
                if isinstance(layer, nn.Linear):
                    in_features = layer.in_features
                    out_features = layer.out_features
                elif isinstance(layer, nn.Conv2d):
                    in_features = layer.in_channels
                    out_features = layer.out_channels
                elif isinstance(layer, Conv1D):
                    in_features = layer.weight.shape[0]
                    out_features = layer.weight.shape[1]
                bias = layer.bias is not None
                if not (self.desc_act) or self.group_size == -1:
                    new_layer = QuantLinear(
                        self.bits,
                        self.group_size,
                        in_features,
                        out_features,
                        bias,
                        use_cuda_fp16=self.use_cuda_fp16,
                        weight_dtype=layer.weight.dtype,
                    )
                else:
                    new_layer = QuantLinear(
                        self.bits, self.group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
                    )
                new_layer.device = device
                setattr(module, attr, new_layer.to(device))
        for name1, child in module.named_children():
            self._replace_by_quant_layers(child, names, name + "." + name1 if name != "" else name1)

    @torch.no_grad()
    def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
        """
        Quantizes the model using the dataset

        Args:
            model (`nn.Module`):
                The model to quantize
            tokenizer (Optional[`Any`], defaults to `None`):
                The tokenizer to use in order to prepare the dataset. You can pass either:
                    - A custom tokenizer object.
                    - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                      using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        Returns:
            `nn.Module`: The quantized model
        """
        if not is_auto_gptq_available():
            raise RuntimeError("auto-gptq is required in order to perform quantization: `pip install auto-gptq`")
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed to quantize the model.")

        model.eval()

        # For Transformers models
        has_config = False
        has_device_map = False
        if hasattr(model, "config"):
            has_config = True
            use_cache = model.config.use_cache
            model.config.use_cache = False
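            # Note: the KV cache is disabled only for the calibration forward passes so that the blocks
            # receive plain hidden states; the original `use_cache` value is restored at the end of
            # `quantize_model`.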
        # If the model has a device_map, we don't move the model: the hooks that were already
        # dispatched will take care of device placement.
        if hasattr(model, "hf_device_map"):
            devices = list(model.hf_device_map.values())
            has_device_map = True
            if "disk" in devices:
                raise ValueError("disk offload is not supported with GPTQ quantization")
            if "cpu" in devices or torch.device("cpu") in devices:
                if len(model.hf_device_map) > 1:
                    logger.info("CPU offload is not recommended. There might be some issues with the memory")
                    hook = None
                    for name, device in model.hf_device_map.items():
                        if device == "cpu":
                            module = recurse_getattr(model, name)
                            remove_hook_from_module(module, recurse=True)
                            module, hook = cpu_offload_with_hook(module, prev_module_hook=hook)
        else:
            has_device_map = False

        if hasattr(model, "dtype"):
            self.use_cuda_fp16 = model.dtype == torch.float16

        if self.model_seqlen is None:
            # We allow a max value of 4028 to avoid passing data with huge length to the model during the calibration step
            self.model_seqlen = min(4028, get_seqlen(model))

        device = get_device(model)

        # Step 1: Prepare the data
        if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str):
            dataset = self.dataset
            logger.info("GPTQQuantizer dataset appears to be already tokenized. Skipping tokenization.")
        else:
            if isinstance(tokenizer, str):
                try:
                    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
                except Exception:
                    raise ValueError(
                        f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
                        with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
                        For now, we only support quantization for text models. Support for vision, speech and multimodal will come later."""
                    )
            if self.dataset is None:
                raise ValueError("You need to pass `dataset` in order to quantize your model")
            elif isinstance(self.dataset, str):
                dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
            elif isinstance(self.dataset, list):
                dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
            else:
                raise ValueError(
                    f"You need to pass a list of strings, a list of tokenized data or a string for `dataset`. Found: {type(self.dataset)}."
                )

        dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)

        # Step 2: get the inputs of the first block
        # To do that, we need to put the modules preceding the first block on the same device as the first block.
        # Then we run the model and it will stop at the first block because we added a pre-hook that raises an Exception after storing the inputs.
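        # The bare `ValueError` raised at the end of `store_input_hook` is only a control-flow signal:
        # it aborts the forward pass once the block's inputs have been captured, and it is caught and
        # ignored in the calibration loops below.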
        layer_inputs = []
        layer_outputs = []
        layer_input_kwargs = []

        if self.block_name_to_quantize is None:
            self.block_name_to_quantize = get_block_name_with_pattern(model)

        if self.module_name_preceding_first_block is None:
            self.module_name_preceding_first_block = get_preceding_modules(model, self.block_name_to_quantize)

        blocks = recurse_getattr(model, self.block_name_to_quantize)

        if not has_device_map:
            # put modules from module_name_preceding_first_block on cuda
            for module_name in self.module_name_preceding_first_block:
                module = recurse_getattr(model, module_name)
                if module is None:
                    raise ValueError(f"Module {module_name} was not found in model")
                module = module.to(0)
            blocks[0] = blocks[0].to(0)

        def store_input_hook(_, input, *args):
            kwargs = args[0]
            if input is None:
                if "hidden_states" in kwargs:
                    input = (kwargs["hidden_states"],)
                else:
                    raise ValueError("No input value found in the forward pass")
            layer_inputs.append(input)
            other_kwargs = {}
            for k, v in kwargs.items():  # make sure other arguments are also captured
                if k not in ["hidden_states"]:
                    other_kwargs[k] = v
            layer_input_kwargs.append(other_kwargs)
            raise ValueError

        if self.cache_block_outputs:
            handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
            for data in dataset:
                for k, v in data.items():
                    # put the data on gpu, we won't put them back to cpu
                    data[k] = v.to(0)
                try:
                    model(**data)
                except ValueError:
                    pass
            handle.remove()

        if not has_device_map:
            blocks[0].to(device)
            for module_name in self.module_name_preceding_first_block:
                module = recurse_getattr(model, module_name)
                if module is None:
                    raise ValueError(f"Module {module_name} was not found in model")

        torch.cuda.empty_cache()

        # Step 3: Quantize the blocks
        quantizers = {}
        for i, block in enumerate(tqdm(blocks, desc=f"Quantizing {self.block_name_to_quantize} blocks ")):
            logger.info(f"Start quantizing block {self.block_name_to_quantize} {i + 1}/{len(blocks)}")

            if not self.cache_block_outputs:
                handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
                for data in dataset:
                    for k, v in data.items():
                        # put the data on gpu, we won't put them back to cpu
                        data[k] = v.to(0)
                    try:
                        model(**data)
                    except ValueError:
                        pass
                handle.remove()

            # move the block to cuda if needed
            # in case we have offloaded modules, we need to put them on cuda because of the GPTQ object
            if not has_device_map or get_device(block) == torch.device("cpu"):
                block = block.to(0)
            layers = get_layers(block)
            if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
                if self.true_sequential:
                    layers_name_list = self.modules_in_block_to_quantize
                else:
                    layers_name_list = [sum(self.modules_in_block_to_quantize, [])]
            else:
                if self.true_sequential:
                    # lazy sequential but works well
                    layers_name_list = [[key] for key in layers.keys()]
                else:
                    layers_name_list = [list(layers.keys())]
            logger.info(f"Modules to quantize {layers_name_list}")
            for subset_name_list in tqdm(layers_name_list, leave=False, desc="Quantizing layers inside the block"):
                subset_layers = {name: layers[name] for name in subset_name_list}
                gptq = {}
                handles = []
                # add a hook for each layer in subset_layers
                for name in subset_layers:
                    gptq[name] = GPTQ(subset_layers[name])
                    gptq[name].quantizer.configure(bits=self.bits, sym=self.sym, perchannel=True)

                    def add_batch(name):
                        def tmp(_, input, output):
                            gptq[name].add_batch(input[0].data, output.data)

                        return tmp

                    # register a forward hook on the layer; keep the handle so it can be removed once calibration is done
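                    # Each calibration pass below routes this layer's inputs/outputs into GPTQ.add_batch,
                    # which accumulates the Hessian statistics that `fasterquant()` consumes afterwards.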
                    handles.append(subset_layers[name].register_forward_hook(add_batch(name)))
                # update the Hessian for each layer in subset_layers thanks to the hook
                for j in range(len(dataset)):
                    # the args are already on the gpu
                    # no need to store the output
                    block(*layer_inputs[j], **layer_input_kwargs[j])
                # remove hooks
                for h in handles:
                    h.remove()

                for name in subset_name_list:
                    logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...")
                    scale, zero, g_idx = gptq[name].fasterquant(
                        percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act
                    )
                    quantizers[f"{self.block_name_to_quantize}.{i}.{name}"] = (
                        gptq[name].quantizer,
                        scale,
                        zero,
                        g_idx,
                    )
                    gptq[name].free()
                del subset_layers

            # we get the new outputs from the partially quantized block
            if self.cache_block_outputs:
                for j in range(len(dataset)):
                    layer_output = block(*layer_inputs[j], **layer_input_kwargs[j])
                    layer_outputs.append(layer_output)

                # put back to device
                if not has_device_map:
                    blocks[i] = block.to(device)
                del layers
                del layer_inputs
                layer_inputs, layer_outputs = layer_outputs, []
            else:
                del layers
                del layer_inputs
                layer_inputs = []
            torch.cuda.empty_cache()

        if self.bits == 4:
            # device not on gpu
            if device == torch.device("cpu") or (has_device_map and any(d in devices for d in ["cpu", "disk"])):
                if not self.disable_exllama:
                    logger.warning(
                        "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
                    )
                    self.disable_exllama = True
            # act order and exllama
            elif self.desc_act and not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE:
                logger.warning(
                    "Using Exllama backend with act_order will reorder the weights offline, thus you will not be able to save the model with the right weights. "
                    "Setting `disable_exllama=True`. You should only use Exllama backend with act_order for inference. "
                )
                self.disable_exllama = True
            elif not self.disable_exllama and self.exllama_version == ExllamaVersion.TWO:
                logger.warning(
                    "Using Exllamav2 backend will reorder the weights offline, thus you will not be able to save the model with the right weights. "
                    "Setting `disable_exllama=True`. You should only use Exllamav2 backend for inference. "
                )
                self.disable_exllama = True
        # Step 4: Pack the model at the end (Replacing the layers)
        self.pack_model(model=model, quantizers=quantizers)

        model.is_quantized = True
        model.quantization_method = QuantizationMethod.GPTQ
        if has_config:
            model.config.use_cache = use_cache
            model.config.quantization_config = self.to_dict()

        # Step 5: Any post-initialization that requires device information, for example buffers initialization on device.
        model = self.post_init_model(model)

        torch.cuda.empty_cache()
        return model

    def post_init_model(self, model):
        """
        Post-initialization that requires device information, for example buffers initialization on device.

        Args:
            model (`nn.Module`):
                The input model
        """
        if self.bits == 4 and not self.disable_exllama:
            if get_device(model) == torch.device("cpu") or (
                hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
            ):
                raise ValueError(
                    "Found modules on cpu/disk. Using Exllama or Exllamav2 backend requires all the modules to be on GPU. "
"You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object" ) class StoreAttr(object): pass model.quantize_config = StoreAttr() model.quantize_config.desc_act = self.desc_act model = autogptq_post_init(model, use_act_order=self.desc_act) if ( self.desc_act and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE) and self.max_input_length is not None ): model = exllama_set_max_input_length(model, self.max_input_length) return model def pack_model( self, model: nn.Module, quantizers: Dict[str, Tuple], ): """ Pack the model by replacing the layers by quantized layers Args: model (`nn.Module`): The model to pack quantizers (`Dict[str,Tuple]`): A mapping of the layer name and the data needed to pack the layer """ QuantLinear = dynamically_import_QuantLinear( use_triton=False, desc_act=self.desc_act, group_size=self.group_size, bits=self.bits, disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE, disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO, ) logger.info("Packing model...") layers = get_layers(model) layers = {n: layers[n] for n in quantizers} self._replace_by_quant_layers(model, quantizers) qlayers = get_layers(model, [QuantLinear]) autogptq_blobs = OrderedDict() for i, name in enumerate(qlayers): logger.info(name) quantizers[name], scale, zero, g_idx = quantizers[name] # so far can only pack layer on CPU layer_device = qlayers[name].device qlayers[name].to("cpu") layers[name], scale, zero, g_idx = layers[name].to("cpu"), scale.to("cpu"), zero.to("cpu"), g_idx.to("cpu") autogptq_blobs[name] = { "prepack": dict( w=layers[name].weight, b=layers[name].bias, scale=scale, zero=zero, g_idx=g_idx ) } qlayers[name].pack(layers[name], scale, zero, g_idx) autogptq_blobs[name]["pack"] = dict( qweight=qlayers[name].qweight, bias=qlayers[name].bias, scales=qlayers[name].scales, qzeros=qlayers[name].qzeros, g_idx=qlayers[name].g_idx, intweight=qlayers[name].intweight ) qlayers[name].to(layer_device) if i==5: break torch.save(autogptq_blobs, "./opt-125m-gptq4.pth") exit() logger.info("Model packed.") def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True): """ Save model state dict and configs Args: model (`nn.Module`): Model to be saved. The model can be wrapped or unwraped. save_dir (`str`): Directory to which to save. Will be created if it doesn't exist. max_shard_size (`str`, defaults to `"10GB"`): The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard which will be bigger than `max_shard_size`. safe_serialization (`bool`, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). 
""" os.makedirs(save_dir, exist_ok=True) model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f: json.dump(self.to_dict(), f, indent=2) def load_quantized_model( model: nn.Module, save_folder: str, quant_config_name: str = GPTQ_CONFIG, state_dict_name: Optional[str] = None, device_map: Optional[str] = None, max_memory: Optional[Dict] = None, no_split_module_classes: Optional[Dict] = None, offload_folder: Optional[str] = None, offload_buffers: Optional[str] = None, offload_state_dict: bool = False, disable_exllama: bool = False, exllama_config: Optional[Dict[str, Any]] = None, max_input_length: Optional[int] = None, ): """ Load quantized weights from the save_folder into the converted model and dispatch the weights according to the device_map. Args: model (`nn.Module`): The model can be enpty or not. save_folder (`str`): Directory to which to load the weights. quant_config_name (`str`, defaults to `GPTQ_CONFIG`): Name of the quantization config file state_dict_name (`Optional[str]`, defaults to `None`): Name of the state dict file device_map (`Optional[str]`, defaults to `None`): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. max_memory (`Optional[Dict]`, defaults to `None`): A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. no_split_module_classes (`Optional[Dict]`, defaults to `None`): A list of layer class names that should never be split across device (for instance any layer that has a residual connection). offload_folder (`Optional[str]`, defaults to `None`): If the `device_map` contains any value `"disk"`, the folder where we will offload weights. offload_buffers (`Optional[str]`, defaults to `None`): In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as well as the parameters. offload_state_dict (`bool`, defaults to `False`): If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map picked contains `"disk"` values. disable_exllama (`Optional[bool]`, defaults to `None`): Whether to use exllama backend. Only works with `bits` = 4. exllama_config (`Optional[Dict[str, Any]]`, defaults to `None`): The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults to `{"version": 2}` if unset. max_input_length (`Optional[int]`, defaults to `None`): The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length. It is specific to the exllama backend with act-order. Returns: `nn.Module`: The quantized model """ if not torch.cuda.is_available(): raise RuntimeError("No GPU found. A GPU is needed to run quantized model.") if not is_auto_gptq_available(): raise RuntimeError("auto-gptq is required in order to load quantized weights : `pip install auto-gptq`") if not is_accelerate_available(): raise RuntimeError( "You need to install accelerate in order to load and dispatch weights to" "a quantized model. 
    if device_map is None:
        device_map = {"": torch.cuda.current_device()}
        logger.info("The device_map was not initialized. Setting device_map to `{'':torch.cuda.current_device()}`.")

    if exllama_config is None:
        exllama_config = {"version": ExllamaVersion.TWO}
    else:
        if "version" not in exllama_config:
            raise ValueError("`exllama_config` needs to have a `version` key")
        elif exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
            version = exllama_config["version"]
            raise ValueError(
                f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {version}"
            )

    # this branch will check if the model comes from huggingface
    try:
        if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
            quantize_config_dict = model.config.quantization_config.to_dict()
        else:
            with open(os.path.join(save_folder, quant_config_name), "r", encoding="utf-8") as f:
                quantize_config_dict = json.load(f)
    except Exception as err:
        raise ValueError(
            f"Failed to load quantization config from {save_folder} (see the traceback for details): {err}\nTip: If the save directory comes from a transformers.PreTrainedModel, make sure that `config.json` contains a 'quantization_config' key."
        ) from err
    quantizer = GPTQQuantizer.from_dict(quantize_config_dict)
    quantizer.disable_exllama = disable_exllama
    quantizer.exllama_config = exllama_config
    quantizer.exllama_version = quantizer.exllama_config["version"]
    quantizer.max_input_length = max_input_length

    model = quantizer.convert_model(model)

    if no_split_module_classes is None:
        no_split_module_classes = quantizer.get_no_split_module_classes(model)

    model = load_checkpoint_and_dispatch(
        model,
        checkpoint=os.path.join(save_folder, state_dict_name) if state_dict_name is not None else save_folder,
        device_map=device_map,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        offload_folder=offload_folder,
        offload_buffers=offload_buffers,
        offload_state_dict=offload_state_dict,
    )

    model = quantizer.post_init_model(model)
    model.is_quantized = True
    model.quantization_method = QuantizationMethod.GPTQ
    model.eval()
    return model
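

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed on import). The model id
# "facebook/opt-125m" and the "./opt-125m-gptq" folder are placeholder choices,
# not anything this module requires:
#
#     import torch
#     from accelerate import init_empty_weights
#     from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
#
#     model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
#     tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
#
#     quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128)
#     quantized_model = quantizer.quantize_model(model, tokenizer)
#     quantizer.save(quantized_model, "./opt-125m-gptq")
#
#     # Reload later: convert an empty model and dispatch the quantized weights onto it.
#     with init_empty_weights():
#         empty_model = AutoModelForCausalLM.from_config(
#             AutoConfig.from_pretrained("facebook/opt-125m"), torch_dtype=torch.float16
#         )
#     empty_model.tie_weights()
#     quantized_model = load_quantized_model(empty_model, save_folder="./opt-125m-gptq", device_map="auto")
# ---------------------------------------------------------------------------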