# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN


class MLP(nn.Module):
    """
    Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.

    All three projections are bias-free `nn.Linear` layers.
    """

    def __init__(self, activation, input_size, intermediate_size, output_size):
        """
        activation: key used to look up the activation function in `ACT2FN`.
        input_size: int. Size of the last dimension of the input.
        intermediate_size: int. Output size of `gate_proj` and `up_proj`.
        output_size: int. Size of the last dimension of the output.
        """
        super().__init__()
        self.input_size = input_size
        self.intermediate_size = intermediate_size
        self.output_size = output_size
        self.gate_proj = nn.Linear(input_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(input_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        self.act_fn = ACT2FN[activation]

    def forward(self, x):
        """Apply the gated projection to `x` and project the result to `output_size`."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize `hidden_states` over its last dimension and scale by `weight`."""
        input_dtype = hidden_states.dtype
        # The statistics are computed in float32; the result is cast back to the
        # input dtype before the learned scale is applied.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze=False,
        device=None,
        dtype=None,
        padding_idx=None,
        **kwargs,
    ) -> None:
        """
        num_embeddings: int. Size of the regular (possibly frozen) vocabulary.
        num_additional_embeddings: int. Number of additional embeddings. Only useful when you
            `partially_freeze=True`.
        embedding_dim: int. Size of each embedding vector.
        partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen.
        device, dtype: forwarded to both embedding tables.
        padding_idx: forwarded to `nn.Embedding.__init__`.

        Raises:
            ValueError: if `padding_idx` is greater than or equal to `num_embeddings`.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
        `max_norm` or `norm_type`. We are not supporting these.
        """
        # Valid indices are 0..num_embeddings-1, so `padding_idx == num_embeddings` is already out of range.
        if padding_idx is not None and padding_idx >= num_embeddings:
            raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=padding_idx,
            **kwargs,
        )
        # Re-assigned with the caller-supplied values after the parent constructor has run.
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)

        # The extra table only exists when there is at least one additional embedding.
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )

    def forward(self, input_ids):
        """
        Look up `input_ids` across the regular table (`self.weight`) and the additional table
        (`self.additional_embedding.weight`). Ids below `num_embeddings` are served by the regular table; ids at or
        above it are served by the additional table.

        In order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
           embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with 0
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
        i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
        usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
        measure.
        """
        if self.num_additional_embeddings == 0:
            # `additional_embedding` is not created in this configuration, so
            # every id is looked up in the regular table.
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        """Return the constructor-level settings shown in the module's repr."""
        return (
            f"num_embeddings={self.num_embeddings}, "
            f"num_additional_embeddings={self.num_additional_embeddings}, "
            f"embedding_dim={self.embedding_dim}, "
            f"partially_freeze={self.partially_freeze}"
        )

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, **kwargs):
        """Not implemented for this class; always raises `NotImplementedError`."""
        raise NotImplementedError


class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
    then it will create `out_additional_features * in_features` additional parameters that are always trained. If
    `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        in_features: int. Size of the last dimension of the input.
        out_features: int. Number of output features produced by the regular `weight`.
        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when
            `partially_freeze=True`.
        bias: bool. Whether the regular layer and the additional layer (if any) have a bias.
        partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be
            trainable. If False, default to the regular behavior of nn.Linear.
        device, dtype: forwarded to both linear layers.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        self.in_features = in_features
        self.out_features = out_features

        if partially_freeze:
            self.weight.requires_grad_(False)
            if bias:
                self.bias.requires_grad_(False)

        # The extra projection only exists when additional output features are requested.
        if out_additional_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=out_additional_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the regular projection and, when `out_additional_features > 0`, concatenate the additional
        projection's output to it along the last dimension.
        """
        output = F.linear(input, self.weight, self.bias)

        if self.out_additional_features > 0:
            additional_features = self.additional_fc(input)
            output = torch.cat((output, additional_features), -1)

        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"out_additional_features={self.out_additional_features}, "
            f"bias={self.bias is not None}, "
            f"partially_freeze={self.partially_freeze}"
        )