Llama-2-7b-NuGPTQ / modeling_llama_nugptq.py
smpanaro's picture
Initial commit
c2fa56c verified
raw
history blame
1.59 kB
from transformers import LlamaForCausalLM
import torch
from torch import nn
class ScaledLinear(nn.Linear):
    """nn.Linear whose output is multiplied by a learned per-output-feature scale.

    Computes ``linear(x) * output_scales``. ``output_scales`` is a
    ``(1, out_features)`` parameter initialized to ones; it keeps that shape so
    existing state dicts containing it still load.

    Bias is not supported: with a bias, the stored bias would have to be
    pre-divided by ``output_scales`` (see the commented-out loader sketch below).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: must be falsy. Defaults to False, the only supported value.
        device: forwarded to nn.Linear and used for ``output_scales``.
        dtype: forwarded to nn.Linear and used for ``output_scales``.

    Raises:
        NotImplementedError: if ``bias`` is truthy.
    """

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        # Validate before nn.Linear allocates weight/bias, and with an explicit
        # raise (not assert) so the check still runs under ``python -O``.
        if bias:
            raise NotImplementedError("bias not supported yet")  # need to divide bias by scales.
        super().__init__(in_features, out_features, bias=False, device=device, dtype=dtype)
        self.output_scales = nn.Parameter(
            torch.ones((1, out_features), device=device, dtype=dtype)
        )

    def forward(self, x):
        # Flatten (1, out_features) -> (out_features,) so the scale broadcasts over
        # the last dim only; multiplying by the 2-D parameter directly would turn an
        # unbatched (in_features,) input into a (1, out_features) output.
        return super().forward(x) * self.output_scales.view(-1)

    # Works for CPU but not CUDA.
    # Starting point if you need to add support for bias.
    # def _load_from_state_dict(self, *args, **kwargs):
    #     # Seems like transformers doesn't call load_state_dict.
    #     # args[0] - state_dict
    #     # args[1] - prefix
    #     args[0][f"{args[1]}output_scales"] = args[0][f"{args[1]}output_scales"].t()
    #     super()._load_from_state_dict(*args, **kwargs)
    #     if self.bias is not None:
    #         self.bias.data = self.bias.data / self.output_scales
class LLamaNuGPTQForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM with its attention/MLP projections swapped for ScaledLinear.

    After the stock model is built, every ``nn.Linear`` submodule whose
    attribute name is listed in ``_SCALED_PROJECTIONS`` is replaced, wherever it
    sits in the module tree, by a freshly constructed ``ScaledLinear`` with the
    same ``in_features``/``out_features`` and bias flag. Weights are not copied
    from the replaced layer; presumably they are filled in when a checkpoint is
    loaded — confirm against the loading code. A projection that has a bias
    cannot be converted, because ScaledLinear rejects bias.
    """

    # Child attribute names that get converted to ScaledLinear.
    _SCALED_PROJECTIONS = (
        "gate_proj", "up_proj", "down_proj",
        "q_proj", "k_proj", "v_proj", "o_proj",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._swap_in_scaled_linears(self)

    @classmethod
    def _swap_in_scaled_linears(cls, parent):
        """Depth-first, replace matching nn.Linear children of *parent* in place."""
        for child_name, child in parent.named_children():
            if child_name not in cls._SCALED_PROJECTIONS or not isinstance(child, nn.Linear):
                # Not a target: descend and look for projections further down.
                cls._swap_in_scaled_linears(child)
                continue
            has_bias = child.bias is not None
            setattr(parent, child_name, ScaledLinear(child.in_features, child.out_features, has_bias))