import math
from functools import partial
from typing import Iterator, Optional, Tuple, Union

import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.nn import Parameter

from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig


def initialized_weights(
    shape: Tuple[int, ...], num_adaptions: int, init: str = "kaiming"
) -> torch.Tensor:
    """Stack `num_adaptions` independently initialized weight tensors of `shape`."""
    weight_data = []
    for _ in range(num_adaptions):
        new_adaption = torch.zeros(shape)
        if init == "kaiming":
            nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
        elif init == "normal":
            nn.init.normal_(new_adaption)
        else:
            raise NotImplementedError
        weight_data.append(new_adaption)
    return torch.stack(weight_data, dim=0)


class LoRAParametrization(nn.Module):
    """Weight parametrization that adds a task-selectable low-rank (LoRA) update."""

    def __init__(
        self,
        fan_in: int,
        fan_out: int,
        layer_type: str = "linear",
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: float = 1,
    ):
        super().__init__()
        # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
        # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
        fan_in_fan_out = layer_type == "embedding"
        self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)

        if layer_type == "linear":
            self.lora_A = nn.Parameter(
                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
            )
            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
        elif layer_type == "embedding":
            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
            self.lora_B = nn.Parameter(
                initialized_weights(
                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
                )
            )
        else:
            raise NotImplementedError

        self.lora_alpha, self.rank = lora_alpha, rank
        self.scaling = lora_alpha / rank
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
        )
        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
        self.register_buffer(
            "lora_dropout_mask",
            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
            persistent=False,
        )
        self.forward_fn = lambda x: x
        self.current_task = None

    def _dropout(self, A):
        # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
        return A * self.lora_dropout(self.lora_dropout_mask)

    def lora_forward(self, X):
        assert self.current_task is not None
        return (
            X
            + torch.matmul(
                *self.swap(
                    (
                        self.lora_B[self.current_task],
                        self.dropout_fn(self.lora_A[self.current_task]),
                    )
                )
            ).view(X.shape)
            * self.scaling
        )

    def forward(self, X):
        return self.forward_fn(X)

    def select_task(self, task=None):
        self.current_task = task
        if task is None:
            self.forward_fn = lambda x: x
        else:
            self.forward_fn = self.lora_forward

    @classmethod
    def from_linear(
        cls,
        layer: nn.Module,
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: int = 1,
    ):
        assert isinstance(layer, nn.Linear)
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="linear",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def from_embedding(
        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
    ):
        assert isinstance(layer, nn.Embedding)
        fan_in, fan_out = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="embedding",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def add_to_layer(
        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
    ):
        if isinstance(layer, nn.Linear):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_linear(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )
        elif isinstance(layer, nn.Embedding):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_embedding(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )

    @classmethod
    def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
        if isinstance(layer, LoRAParametrization):
            layer.select_task(task_idx)


class BertLoRA(BertPreTrainedModel):
    """BertModel wrapper that freezes the base weights and exposes only LoRA parameters."""

    def __init__(
        self,
        config: JinaBertConfig,
        bert: Optional[BertModel] = None,
        add_pooling_layer=True,
        num_adaptions=1,
    ):
        super().__init__(config)
        if bert is None:
            self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
        else:
            self.bert = bert
        self._register_lora(num_adaptions)
        for name, param in super().named_parameters():
            if "lora" not in name:
                param.requires_grad_(False)
        self.select_task(0)

    @classmethod
    def from_bert(cls, *args, num_adaptions=1, **kwargs):
        bert = BertModel.from_pretrained(*args, **kwargs)
        config = JinaBertConfig.from_pretrained(*args, **kwargs)
        return cls(config, bert=bert, num_adaptions=num_adaptions)

    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        self.apply(
            partial(
                LoRAParametrization.add_to_layer,
                num_adaptions=num_adaptions,
                rank=rank,
                lora_dropout_p=lora_dropout_p,
                lora_alpha=lora_alpha,
            )
        )

    def select_task(self, task_idx: Union[None, int]):
        self.apply(
            partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
        )

    def forward(self, *args, **kwargs):
        return self.bert(*args, **kwargs)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        for name, param in super().named_parameters(
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
        ):
            if "lora" in name:
                yield name, param