import math
import os
from functools import partial
from typing import Iterator, Optional, Tuple, Union

import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.nn import Parameter
from transformers import PretrainedConfig

from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig


def _identity(x):
    """Return ``x`` unchanged.

    Used instead of ``lambda x: x`` for callables stored on modules, so that
    those modules stay picklable (lambdas cannot be pickled).
    """
    return x


def _swap_pair(pair):
    """Return the first two items of ``pair`` as a tuple in reversed order."""
    return pair[1], pair[0]


def initialized_weights(
    shape: Tuple[int, ...], num_adaptions: int, init: str = "kaiming"
) -> torch.Tensor:
    """Create ``num_adaptions`` independently initialized tensors and stack them.

    Args:
        shape: Shape of each individual adaption tensor.
        num_adaptions: Number of tensors to create; becomes the leading dim.
        init: ``"kaiming"`` (``nn.init.kaiming_uniform_`` with
            ``a=sqrt(5)``) or ``"normal"`` (``nn.init.normal_``).

    Returns:
        A tensor of shape ``(num_adaptions, *shape)``.

    Raises:
        NotImplementedError: If ``init`` is not one of the supported schemes.
    """
    weight_data = []
    for _ in range(num_adaptions):
        new_adaption = torch.zeros(shape)
        if init == "kaiming":
            nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
        elif init == "normal":
            nn.init.normal_(new_adaption)
        else:
            raise NotImplementedError(f"Unsupported init scheme: {init!r}")
        weight_data.append(new_adaption)
    return torch.stack(weight_data, dim=0)


class LoRAParametrization(nn.Module):
    """Multi-task LoRA parametrization for the ``weight`` of a linear/embedding layer.

    Holds ``num_adaptions`` independent low-rank pairs stacked along dim 0 of
    ``lora_A`` and ``lora_B``. Once a task index is selected with
    :meth:`select_task`, ``forward(W)`` returns
    ``W + scaling * (B @ A)`` for linear layers, or ``W + scaling * (A @ B)``
    for embeddings, using that task's pair (``A`` optionally dropout-masked).
    With no task selected, ``forward`` returns ``W`` unchanged.
    """

    def __init__(
        self,
        fan_in: int,
        fan_out: int,
        layer_type: str = "linear",
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: float = 1,
    ):
        """Allocate the LoRA parameters.

        Args:
            fan_in: First dim of ``lora_A``'s per-task matrix for embeddings,
                second for linear layers (see ``from_linear``/``from_embedding``).
            fan_out: Output-side size of the low-rank update.
            layer_type: ``"linear"`` or ``"embedding"``.
            num_adaptions: Number of independent task adapters.
            rank: Rank of the low-rank update.
            lora_dropout_p: Dropout probability applied to ``lora_A``; 0 disables it.
            lora_alpha: Scale numerator; the update is scaled by ``lora_alpha / rank``.

        Raises:
            NotImplementedError: If ``layer_type`` is not supported.
        """
        super().__init__()
        # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
        # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
        fan_in_fan_out = layer_type == "embedding"
        self.swap = _swap_pair if fan_in_fan_out else _identity

        if layer_type == "linear":
            # A: (num_adaptions, rank, fan_in) Kaiming-init; B: zeros, so the
            # initial update B @ A is zero.
            self.lora_A = nn.Parameter(
                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
            )
            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
        elif layer_type == "embedding":
            # A: zeros (num_adaptions, fan_in, rank); B: normal-init, so the
            # initial update A @ B is zero.
            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
            self.lora_B = nn.Parameter(
                initialized_weights(
                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
                )
            )
        else:
            raise NotImplementedError(f"Unsupported layer_type: {layer_type!r}")

        self.lora_alpha, self.rank = lora_alpha, rank
        self.scaling = lora_alpha / rank
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else _identity
        )
        self.dropout_fn = self._dropout if lora_dropout_p > 0 else _identity
        # Ones of shape (1, fan_in) for linear / (fan_in, 1) for embeddings, so
        # it broadcasts against a single task's lora_A along the fan_in axis.
        self.register_buffer(
            "lora_dropout_mask",
            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
            persistent=False,
        )
        self.forward_fn = _identity
        self.current_task = None

    def _dropout(self, A):
        """Return ``A`` multiplied by a dropout-masked tensor of ones."""
        # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
        return A * self.lora_dropout(self.lora_dropout_mask)

    def lora_forward(self, X):
        """Return ``X`` plus the scaled low-rank update of the current task.

        The product of the selected task's ``lora_B``/``lora_A`` (order swapped
        for embeddings) is reshaped to ``X.shape`` before being added.
        """
        assert self.current_task is not None
        return (
            X
            + torch.matmul(
                *self.swap(
                    (
                        self.lora_B[self.current_task],
                        self.dropout_fn(self.lora_A[self.current_task]),
                    )
                )
            ).view(X.shape)
            * self.scaling
        )

    def forward(self, X):
        """Apply the currently selected behaviour (identity or LoRA update)."""
        return self.forward_fn(X)

    def select_task(self, task=None):
        """Select the adapter index used by ``forward``; ``None`` disables LoRA."""
        self.current_task = task
        if task is None:
            self.forward_fn = _identity
        else:
            self.forward_fn = self.lora_forward

    @classmethod
    def from_linear(
        cls,
        layer: nn.Module,
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: float = 1,
    ):
        """Build a parametrization sized for ``layer`` (must be ``nn.Linear``)."""
        assert isinstance(layer, nn.Linear)
        # nn.Linear stores its weight as (out_features, in_features).
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="linear",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def from_embedding(
        cls,
        layer: nn.Module,
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: float = 1,
    ):
        """Build a parametrization sized for ``layer`` (must be ``nn.Embedding``)."""
        assert isinstance(layer, nn.Embedding)
        # nn.Embedding stores its weight as (num_embeddings, embedding_dim).
        fan_in, fan_out = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="embedding",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def add_to_layer(
        cls,
        layer: nn.Module,
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: float = 1,
    ):
        """Register a LoRA parametrization on ``layer.weight``.

        Only ``nn.Linear`` and ``nn.Embedding`` layers are affected; any other
        module is left untouched, which makes this suitable for ``Module.apply``.
        """
        if isinstance(layer, nn.Linear):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_linear(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )
        elif isinstance(layer, nn.Embedding):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_embedding(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )

    @classmethod
    def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
        """Call ``select_task(task_idx)`` if ``layer`` is a LoRAParametrization."""
        if isinstance(layer, LoRAParametrization):
            layer.select_task(task_idx)


class BertLoRA(BertPreTrainedModel):
    """BERT model with multi-task LoRA adapters on every linear and embedding layer.

    All non-LoRA parameters are frozen at construction, and ``parameters`` /
    ``named_parameters`` only expose parameters whose name contains ``"lora"``.
    """

    def __init__(
        self,
        config: JinaBertConfig,
        bert: Optional[BertModel] = None,
        add_pooling_layer=True,
        num_adaptions=1,
    ):
        """Wrap ``bert`` (or a new ``BertModel(config)``) with LoRA adapters.

        Args:
            config: Model configuration.
            bert: Pre-built backbone; when ``None`` one is created from ``config``.
            add_pooling_layer: Passed to ``BertModel`` only when ``bert`` is ``None``.
            num_adaptions: Number of task adapters per layer. Task 0 is selected.
        """
        super().__init__(config)
        if bert is None:
            self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
        else:
            self.bert = bert
        self._register_lora(num_adaptions)
        # Freeze everything except the LoRA parameters. super() is used because
        # this class's named_parameters already filters to LoRA parameters only.
        for name, param in super().named_parameters():
            if "lora" not in name:
                param.requires_grad_(False)
        self.select_task(0)

    @classmethod
    def from_bert(cls, *args, num_adaptions=1, **kwargs):
        """Load a pretrained ``BertModel`` and its config, then wrap them.

        ``args``/``kwargs`` are passed to both ``BertModel.from_pretrained`` and
        ``JinaBertConfig.from_pretrained``.
        """
        bert = BertModel.from_pretrained(*args, **kwargs)
        config = JinaBertConfig.from_pretrained(*args, **kwargs)
        return cls(config, bert=bert, num_adaptions=num_adaptions)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Load pretrained BERT weights and wrap them with new LoRA adapters.

        Delegates to :meth:`from_bert`. ``num_adaptions`` may be given as a
        keyword argument (default 1). ``cache_dir``, ``force_download``,
        ``local_files_only``, ``token`` and ``revision`` are forwarded to both
        the model and the config loader.

        NOTE: ``model_args``, ``config``, ``ignore_mismatched_sizes``,
        ``use_safetensors`` and other keyword arguments are currently ignored.
        The LoRA parameters are created fresh by ``__init__`` and are not
        loaded from the checkpoint.
        """
        # TODO: choose between from_bert and super().from_pretrained
        num_adaptions = kwargs.pop("num_adaptions", 1)
        return cls.from_bert(
            pretrained_model_name_or_path,
            num_adaptions=num_adaptions,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Attach a LoRAParametrization to every Linear/Embedding submodule."""
        self.apply(
            partial(
                LoRAParametrization.add_to_layer,
                num_adaptions=num_adaptions,
                rank=rank,
                lora_dropout_p=lora_dropout_p,
                lora_alpha=lora_alpha,
            )
        )

    def select_task(self, task_idx: Optional[int]):
        """Select adapter ``task_idx`` on every LoRA layer; ``None`` disables LoRA."""
        self.apply(
            partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
        )

    def forward(self, *args, **kwargs):
        """Run the wrapped ``BertModel`` with all arguments passed through."""
        return self.bert(*args, **kwargs)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Yield only the LoRA parameters (see :meth:`named_parameters`)."""
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        """Yield ``(name, param)`` pairs whose name contains ``"lora"``."""
        for name, param in super().named_parameters(
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
        ):
            if "lora" in name:
                yield name, param