import torch
import torch.nn.functional as F
from torch import nn

from transformers.activations import ACT2FN


class MLP(nn.Module):
    def __init__(self, activation, input_size, intermediate_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.intermediate_size = intermediate_size
        self.output_size = output_size

        self.gate_proj = nn.Linear(input_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(input_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        self.act_fn = ACT2FN[activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
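

# Usage sketch for MLP (illustrative only; the "silu" activation key and the sizes below are
# arbitrary choices for the example, not values taken from any model config):
#
#   mlp = MLP(activation="silu", input_size=16, intermediate_size=64, output_size=16)
#   out = mlp(torch.randn(2, 5, 16))  # gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)) -> (2, 5, 16)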


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        RMSNorm is equivalent to T5LayerNorm.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
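

# Usage sketch for RMSNorm (illustrative only; the sizes below are arbitrary):
#
#   norm = RMSNorm(hidden_size=16, eps=1e-6)
#   out = norm(torch.randn(2, 5, 16))  # normalizes by the root-mean-square over the last dim
#                                      # (no mean subtraction), computed in float32, cast back to input dtype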


class DecoupledEmbedding(nn.Embedding):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters
    that are always trained.
    If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze=False,
        device=None,
        dtype=None,
        padding_idx=None,
        **kwargs,
    ) -> None:
        """
        num_additional_embeddings: int. Number of additional embeddings. Only useful when `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_embedding` is never frozen.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
        `max_norm` or `norm_type`. We are not supporting these.
        """
        if padding_idx is not None and padding_idx > num_embeddings:
            raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=padding_idx,
            **kwargs,
        )
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)

        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )

    def forward(self, input_ids):
        """
        We have 2 embeddings with different indices: one pretrained `self.weight` and another,
        `self.additional_embedding.weight`, that is being trained.

        In order to look up the input ids, we:
        1. find the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (`num_embeddings`),
           since the 2nd embedding starts from 0 and not `num_embeddings`
        3. perform the 2nd embedding lookup
        4. handle the 1st embedding: overwrite the indices belonging to the 2nd embedding with a padding index
        5. perform the 1st embedding lookup
        6. overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        Note: for the 1st embedding lookup we could have looked up only the low indices and skipped the
        padding, but then we would have to create a new tensor and populate it with 2 tensors spread out
        across various indices - i.e. not a simple concat. I haven't benchmarked whether the complex case
        is any faster; given that seqlens are usually relatively short, it's probably not faster, or not
        by much - but it might be a good idea to measure.
        """
        if self.num_additional_embeddings == 0:
            # No additional vocabulary: behave exactly like a regular embedding lookup.
            return F.embedding(input_ids, self.weight)

        # Clone so the caller's tensor is not modified in place.
        input_ids = input_ids.clone()

        # Steps 1-3: route the ids that belong to the additional vocabulary.
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)

        # Steps 4-5: replace out-of-range ids with a valid placeholder (0) and look up the main table.
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # Step 6: overwrite the placeholder rows with the additional-embedding results.
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector
    def extra_repr(self) -> str:
        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.num_embeddings,
            self.num_additional_embeddings,
            self.embedding_dim,
            self.partially_freeze,
        )

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, **kwargs):
        raise NotImplementedError
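

# Usage sketch for DecoupledEmbedding (illustrative only; the vocabulary sizes and ids below are
# arbitrary values chosen for the example, not taken from any model config):
#
#   emb = DecoupledEmbedding(
#       num_embeddings=100, num_additional_embeddings=2, embedding_dim=8, partially_freeze=True
#   )
#   ids = torch.tensor([[1, 5, 100, 101]])  # 100 and 101 route to `additional_embedding`
#   vectors = emb(ids)                      # shape (1, 4, 8)
#   assert not emb.weight.requires_grad and emb.additional_embedding.weight.requires_grad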


class DecoupledLinear(nn.Linear):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `out_additional_features` > 0, then it will create `out_additional_features * in_features` additional
    parameters that are always trained.
    If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when
            `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen and the extra parameters (if any)
            will be trainable. If False, defaults to the regular behavior of nn.Linear.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        self.in_features = in_features
        self.out_features = out_features

        if partially_freeze:
            self.weight.requires_grad_(False)
            if bias:
                self.bias.requires_grad_(False)

        if out_additional_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=out_additional_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = F.linear(input, self.weight, self.bias)

        if self.out_additional_features > 0:
            additional_features = self.additional_fc(input)
            output = torch.cat((output, additional_features), -1)

        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.out_features,
            self.out_additional_features,
            self.bias is not None,
            self.partially_freeze,
        )
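

# Usage sketch for DecoupledLinear (illustrative only; the sizes below are arbitrary values
# chosen for the example, not taken from any model config):
#
#   lin = DecoupledLinear(in_features=8, out_features=100, out_additional_features=2, partially_freeze=True)
#   out = lin(torch.randn(4, 8))  # shape (4, 102): 100 frozen output dims + 2 always-trainable dims
#   assert not lin.weight.requires_grad and lin.additional_fc.weight.requires_grad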