import math
from functools import partial
from typing import Iterator, Tuple

import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.nn import Parameter

from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig


def initialized_weights(shape, num_adaptions, init='kaiming'):
    """Create ``num_adaptions`` independently initialized tensors of ``shape`` and stack them.

    Args:
        shape: Shape of a single adaption's weight.
        num_adaptions: Number of independently initialized copies to create.
        init: ``'kaiming'`` (``nn.init.kaiming_uniform_`` with ``a=sqrt(5)``) or
            ``'normal'`` (``nn.init.normal_``).

    Returns:
        A tensor of shape ``(num_adaptions, *shape)``.

    Raises:
        NotImplementedError: If ``init`` is not a supported scheme.
    """
    weight_data = []
    for _ in range(num_adaptions):
        new_adaption = torch.zeros(shape)
        if init == 'kaiming':
            nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
        elif init == 'normal':
            nn.init.normal_(new_adaption)
        else:
            raise NotImplementedError
        weight_data.append(new_adaption)
    return torch.stack(weight_data, dim=0)


class LoRAParametrization(nn.Module):
    """Multi-task LoRA update applied to a layer's ``weight`` through ``torch.nn.utils.parametrize``.

    Holds ``num_adaptions`` independent low-rank factor pairs ``(lora_A, lora_B)``.
    When a task is selected via :meth:`select_task`, ``forward`` returns the
    original weight plus ``scaling * (B @ A)`` (linear layers) or
    ``scaling * (A @ B)`` (embeddings) for that task's pair. With no task
    selected, the weight is returned unchanged.
    """

    def __init__(self, fan_in, fan_out, layer_type='linear', num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        super().__init__()
        # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
        # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
        fan_in_fan_out = (layer_type == 'embedding')
        self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)

        # In both layouts one factor starts at zero, so the LoRA update is zero at init.
        if layer_type == 'linear':
            # lora_A: (num_adaptions, rank, fan_in); lora_B: (num_adaptions, fan_out, rank)
            self.lora_A = nn.Parameter(initialized_weights((rank, fan_in), num_adaptions, init='kaiming'))
            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
        elif layer_type == 'embedding':
            # lora_A: (num_adaptions, fan_in, rank); lora_B: (num_adaptions, rank, fan_out)
            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
            self.lora_B = nn.Parameter(initialized_weights((rank, fan_out), num_adaptions=num_adaptions, init='normal'))
        else:
            raise NotImplementedError

        self.lora_alpha, self.rank = lora_alpha, rank
        self.scaling = lora_alpha / rank
        self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
        # All-ones mask that is dropped out and multiplied into A (see _dropout);
        # shape (1, fan_in) for linear, (fan_in, 1) for embedding, so it broadcasts against A.
        self.register_buffer("lora_dropout_mask", torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype), persistent=False)
        # Identity until a task is selected.
        self.forward_fn = lambda x: x
        self.current_task = None

    def _dropout(self, A):
        """Apply dropout to the LoRA input side by masking ``A`` instead of the activations."""
        # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
        return A * self.lora_dropout(self.lora_dropout_mask)

    def lora_forward(self, X):
        """Return ``X`` plus the scaled low-rank update of the current task, reshaped to ``X.shape``."""
        assert self.current_task is not None
        return X + torch.matmul(*self.swap((self.lora_B[self.current_task], self.dropout_fn(self.lora_A[self.current_task])))).view(X.shape) * self.scaling

    def forward(self, X):
        """Map the original weight ``X`` to the effective weight (identity if no task is selected)."""
        return self.forward_fn(X)

    def select_task(self, task=None):
        """Select which adaption index to apply; ``None`` disables the LoRA update."""
        self.current_task = task
        if task is None:
            self.forward_fn = lambda x: x
        else:
            self.forward_fn = self.lora_forward

    @classmethod
    def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Build a linear-layout parametrization sized from ``layer.weight`` (unpacked as ``fan_out, fan_in``)."""
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='linear', rank=rank,
            lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
        )

    @classmethod
    def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Build an embedding-layout parametrization sized from ``layer.weight`` (unpacked as ``fan_in, fan_out``)."""
        fan_in, fan_out = layer.weight.shape
        return cls(
            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='embedding', rank=rank,
            lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
        )

    @classmethod
    def add_to_layer(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Register a LoRA parametrization on ``layer.weight`` for ``nn.Linear`` / ``nn.Embedding``.

        Any other module type is left untouched, so this can be passed to ``nn.Module.apply``.
        """
        if isinstance(layer, nn.Linear):
            parametrize.register_parametrization(layer, "weight", cls.from_linear(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
        elif isinstance(layer, nn.Embedding):
            parametrize.register_parametrization(layer, "weight", cls.from_embedding(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))

    @classmethod
    def select_task_for_layer(cls, layer, task_idx=None):
        """Call ``select_task(task_idx)`` on ``layer`` if it is a LoRAParametrization; otherwise do nothing."""
        if isinstance(layer, LoRAParametrization):
            layer.select_task(task_idx)


class BertLoRA(BertPreTrainedModel):
    """BERT wrapper whose Linear and Embedding weights carry multi-task LoRA adapters.

    Every parameter whose name does not contain ``'lora'`` has ``requires_grad``
    disabled, and :meth:`parameters` / :meth:`named_parameters` only yield
    parameters whose name contains ``'lora'``.
    """

    def __init__(self, config: JinaBertConfig, add_pooling_layer=True, num_adaptions=1):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
        self._register_lora(num_adaptions)
        self._freeze_non_lora_params()

    def _freeze_non_lora_params(self):
        """Disable gradients for every parameter whose name does not contain ``'lora'``."""
        # Go through nn.Module's enumeration: this class's named_parameters override
        # only yields LoRA parameters, so it would never reach the ones to freeze.
        for name, param in super().named_parameters():
            if 'lora' not in name:
                param.requires_grad_(False)

    def from_bert(self, *args, num_adaptions=1, **kwargs):
        """Replace ``self.bert`` with ``BertModel.from_pretrained(*args, **kwargs)`` and add LoRA to it.

        As in ``__init__``, all non-LoRA parameters of the loaded model are frozen.
        """
        self.bert = BertModel.from_pretrained(*args, **kwargs)
        self._register_lora(num_adaptions)
        # Previously the freshly loaded weights were left trainable here, unlike __init__.
        self._freeze_non_lora_params()

    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Attach a LoRAParametrization to every ``nn.Linear`` / ``nn.Embedding`` submodule."""
        self.apply(partial(LoRAParametrization.add_to_layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))

    def select_task(self, task_idx=None):
        """Select adaption ``task_idx`` on every LoRA parametrization; ``None`` disables all adapters."""
        self.apply(partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx))

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped ``BertModel``."""
        return self.bert(*args, **kwargs)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Yield only the LoRA parameters (see :meth:`named_parameters`)."""
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self,
        prefix: str = '',
        recurse: bool = True,
        remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        """Yield ``(name, param)`` pairs only for parameters whose name contains ``'lora'``."""
        for name, param in super().named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate):
            if 'lora' in name:
                yield name, param