# gptj.py — retrieved from the Hugging Face Hub file view.
# Author: Enkhai · Commit 16cec6d ("Update gptj.py") · Hub malware scan: no virus · Size: 2.88 kB
import torch
from torch import nn
from .lora import FrozenBNBLinear, FrozenBNBEmbedding
from .config import GPTJLoraConfig
import transformers
from transformers import AutoModelForCausalLM
def add_adapters(model, adapter_dim=16):
    """Attach a trainable bottleneck adapter to every frozen 8-bit layer in *model*.

    Each ``FrozenBNBLinear`` gets ``module.adapter``, an ``nn.Sequential`` that maps
    ``in_features -> adapter_dim -> out_features`` through two bias-free linear layers.

    Each ``FrozenBNBEmbedding`` gets ``module.adapter``, an ``nn.Sequential`` of
    ``nn.Embedding(num_embeddings, adapter_dim)`` followed by a bias-free
    ``adapter_dim -> embedding_dim`` projection.

    In both cases the second layer's weight is zero-initialised, so the adapter output
    starts at zero. Presumably the frozen layer adds this output to its own result;
    that logic lives in ``.lora``, not here.

    Args:
        model: Module tree to walk, typically one already processed by ``convert_to_int8``.
        adapter_dim: Bottleneck width of every adapter; must be positive.

    Raises:
        ValueError: If ``adapter_dim`` is not positive.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if adapter_dim <= 0:
        raise ValueError(f"adapter_dim must be positive, got {adapter_dim!r}")

    # Snapshot the module tree (as convert_to_int8 does) so the adapter submodules
    # registered below are not picked up by the traversal that is creating them.
    for module in list(model.modules()):
        if isinstance(module, FrozenBNBLinear):
            adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
        elif isinstance(module, FrozenBNBEmbedding):
            adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
        else:
            continue
        # Zero the output projection so the adapter's output starts at zero.
        nn.init.zeros_(adapter[1].weight)
        module.adapter = adapter
def _absmax_placeholder(weight):
    """Return a zero float tensor with ceil(weight.numel() / 4096) entries.

    NOTE(review): presumably one absmax slot per 4096-element quantization block,
    matching the blockwise format FrozenBNBLinear/FrozenBNBEmbedding expect — confirm
    against ``.lora``.
    """
    slot_count = (weight.numel() - 1) // 4096 + 1
    return torch.zeros(slot_count)


def convert_to_int8(model):
    """Replace every ``nn.Linear`` / ``nn.Embedding`` child in *model* with a frozen 8-bit placeholder.

    Linear layers become ``FrozenBNBLinear`` and keep their original ``bias``.
    Embedding layers become ``FrozenBNBEmbedding``.

    The replacements hold zero-filled tensors: a ``uint8`` weight of the original
    shape, an absmax buffer, and a 256-entry code table. Presumably the real values
    are loaded afterwards from a quantized checkpoint.

    No adapters are added here; ``add_adapters`` does that separately.
    The module tree is modified in place and nothing is returned.
    """
    # Take a snapshot of the tree up front because children are swapped out below.
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    absmax=_absmax_placeholder(child.weight),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=_absmax_placeholder(child.weight),
                    code=torch.zeros(256),
                )
            else:
                # Any other module type is left untouched.
                continue
            setattr(parent, child_name, replacement)
class GPTJLoraBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J transformer block whose attention and MLP sublayers use frozen 8-bit placeholders.

    After the stock block is built, every ``nn.Linear``/``nn.Embedding`` inside
    ``self.attn`` and ``self.mlp`` is swapped by ``convert_to_int8``. Other
    submodules (e.g. the layer norm) are not touched.
    """

    def __init__(self, config, *args, **kwargs):
        # Forward any extra constructor arguments to the stock block. Recent
        # transformers releases build blocks as ``GPTJBlock(config, layer_idx=i)``;
        # a ``config``-only signature would raise TypeError once this class
        # replaces GPTJBlock. Older releases pass only ``config``, which still works.
        super().__init__(config, *args, **kwargs)
        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)
class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """Bare GPT-J model with all linear/embedding layers turned into frozen 8-bit placeholders.

    NOTE(review): this subclass shadows the transformers ``GPTJModel`` name inside
    this module only. Nothing visible here registers it or patches it into
    transformers, so ``GPTJForCausalLM`` still builds the stock ``GPTJModel``.
    """

    def __init__(self, config):
        # Build the standard GPT-J module tree first.
        super().__init__(config)
        # Then walk the whole tree and swap every remaining nn.Linear / nn.Embedding.
        convert_to_int8(self)
class GPTJLoraForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal language model with frozen 8-bit weights and optional adapters.

    Every ``nn.Linear``/``nn.Embedding`` in the model is converted with
    ``convert_to_int8``. If the config requests it, bottleneck adapters are then
    attached with ``add_adapters``.
    """

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)

        # Read the adapter flag under its correct name, ``add_adapters``. Fall back to
        # the misspelled ``add_apapters`` so existing configs that use that spelling
        # keep working. When neither attribute is set, no adapters are added.
        use_adapters = getattr(config, "add_adapters", None)
        if use_adapters is None:
            use_adapters = getattr(config, "add_apapters", False)
        if use_adapters:
            # NOTE(review): adapters use add_adapters' default width. If
            # GPTJLoraConfig carries an adapter size, it should be passed here.
            add_adapters(self)
# Swap GPT-J's block class in the transformers module so that GPT-J models
# constructed after this import use GPTJLoraBlock. Its attention and MLP layers
# are converted to 8-bit placeholders.
# NOTE(review): this patch is process-wide. It also affects plain GPT-J models
# that are not loaded through GPTJLoraConfig.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJLoraBlock  # monkey-patch GPT-J

# Expose the custom config/model pair through the Auto* factories.
# NOTE(review): AutoConfig.register expects GPTJLoraConfig.model_type == "gptj-lora"
# — confirm in .config.
transformers.AutoConfig.register("gptj-lora", GPTJLoraConfig)
AutoModelForCausalLM.register(GPTJLoraConfig, GPTJLoraForCausalLM)