import torch
import torch.nn as nn

import transformers
from transformers import BloomForCausalLM as LM

from .gptq import *
from .quant import *


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    """Recursively collect submodules of the given types, keyed by dotted path."""
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res
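
# For reference, a sketch of what find_layers returns on a BLOOM model
# (module paths are illustrative of the standard BLOOM layout, not verified
# against any particular checkpoint):
#
#   layers = find_layers(model)
#   # {'transformer.h.0.self_attention.query_key_value': nn.Linear(...),
#   #  'transformer.h.0.mlp.dense_h_to_4h': nn.Linear(...),
#   #  ...
#   #  'lm_head': nn.Linear(...)}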


class SakuraForCausalLM(LM):

    def __init__(self, *args, **kwargs):
        # Stub out weight initialization: the real weights are loaded from a
        # checkpoint afterwards, so randomly initializing the full model here
        # is wasted work.
        def noop(*args, **kwargs):
            pass
        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop
        transformers.modeling_utils._init_weights = False

        # Build the model in half precision, then restore the default dtype.
        torch.set_default_dtype(torch.half)
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)
        self.eval()

        # Swap every Linear/Conv2d except the output head for its quantized
        # counterpart: 4-bit weights, groupsize -1 (no grouping).
        layers = find_layers(self)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant(self, layers, 4, -1)
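

# A minimal usage sketch, assuming the GPTQ scripts in this repo have already
# produced a 4-bit state dict (the checkpoint filename below is hypothetical).
# Run as a module, e.g. `python -m <package>.<this_file>`, so the relative
# imports above resolve:
if __name__ == "__main__":
    from transformers import BloomConfig

    config = BloomConfig.from_pretrained("bigscience/bloom-560m")
    model = SakuraForCausalLM(config)
    model.load_state_dict(torch.load("bloom-4bit.pt"))
    print(model)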