import math
from functools import partial
from typing import Iterator, Optional, Tuple, Union

import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.nn import Parameter

from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig


def initialized_weights(
    shape: Tuple[int, ...], num_adaptions: int, init: str = "kaiming"
) -> torch.Tensor:
    weight_data = []
    for _ in range(num_adaptions):
        new_adaption = torch.zeros(shape)
        if init == "kaiming":
            nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
        elif init == "normal":
            nn.init.normal_(new_adaption)
        else:
            raise NotImplementedError(f"Unknown initialization scheme: {init}")
        weight_data.append(new_adaption)
    return torch.stack(weight_data, dim=0)


class LoRAParametrization(nn.Module):
    def __init__(
        self,
        fan_in: int,
        fan_out: int,
        layer_type: str = "linear",
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: float = 1,
    ):
        super().__init__()

        # nn.Linear stores its weight as (fan_out, fan_in) while nn.Embedding
        # stores it as (fan_in, fan_out); `swap` reorders the low-rank factors
        # so the update matches the parametrized weight's layout.
        fan_in_fan_out = layer_type == "embedding"
        self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)

        if layer_type == "linear":
            self.lora_A = nn.Parameter(
                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
            )
            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
        elif layer_type == "embedding":
            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
            self.lora_B = nn.Parameter(
                initialized_weights(
                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
                )
            )
        else:
            raise NotImplementedError(f"Unknown layer type: {layer_type}")

        self.lora_alpha, self.rank = lora_alpha, rank
        self.scaling = lora_alpha / rank
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
        )
        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
        self.register_buffer(
            "lora_dropout_mask",
            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
            persistent=False,
        )
        self.forward_fn = lambda x: x
        self.current_task = None

    def _dropout(self, A):
        # Multiplying A by a dropped-out ones mask zeroes whole input dimensions,
        # which is equivalent to applying dropout to the input of the LoRA branch.
        return A * self.lora_dropout(self.lora_dropout_mask)

    def lora_forward(self, X):
        # X is the original weight tensor being parametrized; the low-rank product
        # of lora_B and lora_A (reordered via `swap` to match the weight layout)
        # is added on top, scaled by lora_alpha / rank.
        assert self.current_task is not None
        return (
            X
            + torch.matmul(
                *self.swap(
                    (
                        self.lora_B[self.current_task],
                        self.dropout_fn(self.lora_A[self.current_task]),
                    )
                )
            ).view(X.shape)
            * self.scaling
        )

    def forward(self, X):
        return self.forward_fn(X)

    def select_task(self, task=None):
        # `None` disables LoRA; an integer index selects one of the adapters.
        self.current_task = task
        if task is None:
            self.forward_fn = lambda x: x
        else:
            self.forward_fn = self.lora_forward

    @classmethod
    def from_linear(
        cls,
        layer: nn.Module,
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: int = 1,
    ):
        assert isinstance(layer, nn.Linear)
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="linear",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def from_embedding(
        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
    ):
        assert isinstance(layer, nn.Embedding)
        fan_in, fan_out = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="embedding",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def add_to_layer(
        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
    ):
        if isinstance(layer, nn.Linear):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_linear(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )
        elif isinstance(layer, nn.Embedding):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_embedding(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )

    @classmethod
    def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
        if isinstance(layer, LoRAParametrization):
            layer.select_task(task_idx)

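# Minimal sketch of how the classmethods above can be used on a standalone layer
# (shapes and hyperparameters are arbitrary; `apply` reaches the registered
# parametrization because it is a submodule of the layer):
#
#   linear = nn.Linear(16, 8)
#   LoRAParametrization.add_to_layer(linear, num_adaptions=2, rank=4)
#   linear.apply(partial(LoRAParametrization.select_task_for_layer, task_idx=1))
#   y = linear(torch.randn(3, 16))  # weight + adapter 1
#   linear.apply(partial(LoRAParametrization.select_task_for_layer, task_idx=None))
#   y = linear(torch.randn(3, 16))  # LoRA disabled, original weight only

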
class BertLoRA(BertPreTrainedModel):
    def __init__(
        self,
        config: JinaBertConfig,
        bert: Optional[BertModel] = None,
        add_pooling_layer: bool = True,
        num_adaptions: int = 1,
    ):
        super().__init__(config)
        if bert is None:
            self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
        else:
            self.bert = bert
        self._register_lora(num_adaptions)
        # Freeze everything except the LoRA parameters.
        for name, param in super().named_parameters():
            if "lora" not in name:
                param.requires_grad_(False)
        self.select_task(0)

    @classmethod
    def from_bert(cls, *args, num_adaptions=1, **kwargs):
        bert = BertModel.from_pretrained(*args, **kwargs)
        config = JinaBertConfig.from_pretrained(*args, **kwargs)
        return cls(config, bert=bert, num_adaptions=num_adaptions)

    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        self.apply(
            partial(
                LoRAParametrization.add_to_layer,
                num_adaptions=num_adaptions,
                rank=rank,
                lora_dropout_p=lora_dropout_p,
                lora_alpha=lora_alpha,
            )
        )

    def select_task(self, task_idx: Union[None, int]):
        self.apply(
            partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
        )

    def forward(self, *args, **kwargs):
        return self.bert(*args, **kwargs)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        # Expose only the LoRA parameters so optimizers train just the adapters.
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        for name, param in super().named_parameters(
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
        ):
            if "lora" in name:
                yield name, param
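

# Minimal usage sketch for BertLoRA (the checkpoint name and input tensors are
# placeholders; assumes a checkpoint compatible with JinaBertConfig):
#
#   model = BertLoRA.from_bert("path/to/checkpoint", num_adaptions=2)
#   model.select_task(0)  # route forward passes through adapter 0
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # LoRA params only
#   outputs = model(input_ids=input_ids, attention_mask=attention_mask)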