# Copyright Forge 2024
import torch
import bitsandbytes as bnb

from backend import utils
from bitsandbytes.nn.modules import Params4bit, QuantState
from bitsandbytes.functional import dequantize_4bit
def functional_linear_4bits(x, weight, bias):
    """Apply a 4-bit quantized linear layer: x @ W^T + b.

    The matmul runs through bitsandbytes using the weight's stored
    quant_state; the result is cast back to x's dtype/device.
    """
    result = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    return result.to(x)
def functional_dequantize_4bit(weight):
    """Materialize a 4-bit quantized weight back into a dense tensor."""
    quant_state = weight.quant_state
    return dequantize_4bit(
        weight,
        quant_state=quant_state,
        blocksize=weight.blocksize,
        quant_type=weight.quant_type,
    )
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    """Copy a bitsandbytes QuantState onto *device*.

    Returns None for a None input. When *device* is omitted, the device of
    the state's absmax tensor is used. Nested (double-quantized) states copy
    their inner state2 and offset as well.
    """
    if state is None:
        return None

    target = device if device is not None else state.absmax.device

    nested_copy = None
    if state.nested:
        inner = state.state2
        nested_copy = QuantState(
            absmax=inner.absmax.to(target),
            shape=inner.shape,
            code=inner.code.to(target),
            blocksize=inner.blocksize,
            quant_type=inner.quant_type,
            dtype=inner.dtype,
        )

    offset_copy = state.offset.to(target) if state.nested else None

    return QuantState(
        absmax=state.absmax.to(target),
        shape=state.shape,
        code=state.code.to(target),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=offset_copy,
        state2=nested_copy,
    )
class ForgeParams4bit(Params4bit):
    """Params4bit variant that keeps its quantization metadata intact
    across .to() and .pin_memory() calls."""

    def to(self, *args, **kwargs):
        target_device, target_dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        heading_to_cuda = target_device is not None and target_device.type == "cuda"
        if heading_to_cuda and not self.bnb_quantized:
            # First move onto a CUDA device performs the actual 4-bit quantization.
            return self._quantize(target_device)
        moved = torch.nn.Parameter.to(self, device=target_device, dtype=target_dtype, non_blocking=non_blocking)
        return ForgeParams4bit(
            moved,
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, target_device),
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
        )

    def pin_memory(self, device=None):
        pinned = torch.Tensor.pin_memory(self, device=device)
        # Re-wrap so the pinned copy keeps all quantization attributes.
        return ForgeParams4bit(
            pinned,
            requires_grad=self.requires_grad,
            quant_state=self.quant_state,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
        )
class ForgeLoader4Bit(torch.nn.Module):
    """Lazy container for a 4-bit weight.

    Holds no real weight at construction; ``_load_from_state_dict`` fills it
    in later, either from prequantized bitsandbytes statistics or from a
    plain dense tensor wrapped into ``ForgeParams4bit``.
    """

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        # Placeholder that carries the target device/dtype until the real
        # weight arrives; deleted once a weight has been loaded.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.bias = None
        self.quant_type = quant_type

    def _apply(self, fn, recurse=True):
        # Apply fn to weight/bias directly instead of going through Module's
        # default parameter traversal (weight may be a ForgeParams4bit).
        # NOTE(review): `recurse` is accepted but never used — confirm intended.
        if self.weight is not None:
            self.weight = utils.tensor2parameter(fn(self.weight))
        if self.bias is not None:
            self.bias = utils.tensor2parameter(fn(self.bias))
        return self

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # Also serialize the quantization statistics (as "weight.<stat>" keys)
        # so the weight can be rebuilt via from_prequantized on load.
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
        return

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Keys of the form "<prefix>weight.<stat>" carry quantization state.
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}

        if any('bitsandbytes' in k for k in quant_state_keys):
            # Case 1: checkpoint already contains bitsandbytes 4-bit stats —
            # reconstruct the quantized parameter directly on dummy's device.
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}

            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + 'weight'],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                device=self.dummy.device,
            )

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        elif hasattr(self, 'dummy'):
            # Case 2: dense checkpoint — wrap the raw tensor so it quantizes
            # on its first move to CUDA (see ForgeParams4bit.to).
            if prefix + 'weight' in state_dict:
                self.weight = ForgeParams4bit(
                    state_dict[prefix + 'weight'].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=False,
                    blocksize=64,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                )

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        else:
            # Case 3: weight already materialized — defer to default loading.
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def reload_weight(self, weight):
        # Re-wrap a fresh dense tensor with the previous quantization settings;
        # bnb_quantized=False means it will be quantized again on the next
        # move to a CUDA device (see ForgeParams4bit.to).
        self.weight = ForgeParams4bit(
            weight,
            requires_grad=False,
            compress_statistics=self.weight.compress_statistics,
            blocksize=self.weight.blocksize,
            quant_type=self.weight.quant_type,
            quant_storage=self.weight.quant_storage,
            bnb_quantized=False
        )
        return self