from transformers import LlamaForCausalLM
import torch
from torch import nn

class ScaledLinear(nn.Linear):
    """nn.Linear with a learnable per-output-channel scale applied to the output."""

    def __init__(self, in_features, out_features, bias=False):
        # Check before constructing: a default of bias=True would always
        # trip this assertion, so the default is False.
        assert not bias, "bias not supported yet"
        super().__init__(in_features, out_features, bias=bias)
        # One learnable scale per output channel, initialized to 1 (an identity
        # scaling, so a fresh ScaledLinear matches the nn.Linear it replaces).
        self.output_scales = nn.Parameter(torch.ones((1, out_features)))

    def forward(self, x):
        # Standard linear transform, then per-channel rescaling of the output.
        return super().forward(x) * self.output_scales
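
# Illustrative sanity check (not in the original file): with the scales at
# their ones-initialization, ScaledLinear is numerically identical to a plain
# bias-free nn.Linear, e.g.:
#
#   layer = ScaledLinear(16, 32, bias=False)
#   x = torch.randn(4, 16)
#   assert torch.allclose(layer(x), nn.functional.linear(x, layer.weight))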
class LLamaNuGPTQForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM whose attention/MLP projections are ScaledLinear."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def replace_linear_modules(module):
            # Recursively swap every named projection for a ScaledLinear. This
            # runs during construction, i.e. before from_pretrained loads the
            # checkpoint, so the pretrained weights land in the new modules;
            # output_scales is absent from the checkpoint and keeps its
            # ones-initialization.
            for name, mod in module.named_children():
                if isinstance(mod, nn.Linear) and name in [
                    "gate_proj", "up_proj", "down_proj",
                    "q_proj", "k_proj", "v_proj", "o_proj",
                ]:
                    setattr(
                        module,
                        name,
                        ScaledLinear(mod.in_features, mod.out_features,
                                     bias=mod.bias is not None),
                    )
                else:
                    replace_linear_modules(mod)

        replace_linear_modules(self)
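
# Minimal usage sketch (an assumption, not part of the original file; the
# checkpoint name below is only a placeholder). from_pretrained constructs
# the model, whose __init__ swaps the projections, and then loads the
# checkpoint; the missing output_scales keys are simply reported as newly
# initialized parameters.
if __name__ == "__main__":
    model = LLamaNuGPTQForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    scaled = [n for n, m in model.named_modules() if isinstance(m, ScaledLinear)]
    print(f"replaced {len(scaled)} nn.Linear modules with ScaledLinear")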