import torch
from torch import nn

import transformers

from .config import GPTJLoraConfig
from .lora import FrozenBNBEmbedding, FrozenBNBLinear


def add_adapters(model, adapter_dim=16):
    """Attach trainable low-rank (LoRA-style) adapters to every frozen 8-bit module."""
    assert adapter_dim > 0

    for module in model.modules():
        if isinstance(module, FrozenBNBLinear):
            module.adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
            # Zero-init the up-projection so the adapter starts as a no-op.
            nn.init.zeros_(module.adapter[1].weight)
        elif isinstance(module, FrozenBNBEmbedding):
            module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)


def convert_to_int8(model):
    """Replace every nn.Linear and nn.Embedding child with a frozen 8-bit equivalent."""
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                setattr(
                    module,
                    name,
                    FrozenBNBLinear(
                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                        # Blockwise-quantization buffers: one absmax scale per 4096-element
                        # block and a 256-entry codebook, filled when the quantized
                        # checkpoint is loaded.
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                        bias=child.bias,
                    ),
                )
            elif isinstance(child, nn.Embedding):
                setattr(
                    module,
                    name,
                    FrozenBNBEmbedding(
                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                    ),
                )


class GPTJLoraBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    config_class = GPTJLoraConfig

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    config_class = GPTJLoraConfig

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


class GPTJLoraForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    config_class = GPTJLoraConfig

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)
        if config.add_adapters:
            add_adapters(self)


# Monkey-patch GPT-J so models built by transformers use the 8-bit LoRA block.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJLoraBlock
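

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, hedged example of how the classes above are meant to be used:
# build a small GPT-J-shaped model whose Linear/Embedding layers are replaced
# by frozen 8-bit placeholders, leaving the LoRA-style adapters as the only
# trainable parameters. The config field names mirror GPTJConfig; the
# `add_adapters` flag and the "adapter" naming convention are assumptions about
# GPTJLoraConfig and .lora, not guarantees from this module.
if __name__ == "__main__":
    config = GPTJLoraConfig(
        vocab_size=1000,
        n_positions=128,
        n_embd=256,
        n_layer=2,
        n_head=4,
        rotary_dim=64,
        add_adapters=True,  # assumed flag read by GPTJLoraForCausalLM.__init__
    )
    model = GPTJLoraForCausalLM(config)

    # Keep only the adapter layers trainable; the 8-bit base weights stay frozen.
    for name, param in model.named_parameters():
        param.requires_grad = "adapter" in name

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable adapter parameters: {trainable}")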