import torch from torch import Tensor, nn from torch.nn import Sequential from torch.utils.checkpoint import checkpoint, checkpoint_sequential from xformers.components.attention.utils import maybe_merge_masks from xformers.components import MultiHeadDispatch from xformers.components.attention import ScaledDotProduct from transformers import AutoTokenizer class RotaryEmbedding(nn.Module): def __init__( self, dim_per_head: int, max_seq_len: int = 4096, interpolation_ratio: float | None = 0.25, device=None, dtype=None, ): super().__init__() self.dim_per_head = dim_per_head self.max_seq_len = max_seq_len freqs = 1.0 / ( 10000 ** ( torch.arange(0, dim_per_head, 2, device=device, dtype=dtype).float() / 6 ) ) freqs = torch.repeat_interleave(freqs, 2) r = ( freqs * torch.arange(max_seq_len, device=device, dtype=dtype).float()[:, None] ) if interpolation_ratio is not None: r = r * interpolation_ratio r1 = r.cos() self.register_buffer("r1", r1) r2 = r.sin() self.register_buffer("r2", r2) aranged = torch.arange(dim_per_head, device=device, dtype=dtype) mask1 = torch.where( aranged % 2 == 1, aranged - 1, aranged + 1, ).float() self.register_buffer("mask1", mask1) mask2 = torch.where(aranged % 2 == 0, -1, 1).float() self.register_buffer("mask2", mask2) def forward(self, x: Tensor): """ Args: x (Tensor): input tensor. shape: (bs, seq_len, n_heads, dim_per_head) Returns: Tensor: input tensor with rotary embeddings. shape: (bs, seq_len, n_heads, dim_per_head) """ assert ( x.ndim == 4 ), "input must have 4 dimensions: (bs, n_heads, seq_len, dim_per_head)" assert x.shape[3] % 2 == 0, "dim_per_head must be divisible by 2" x = x.transpose(1, 2) return ( x * self.r1[None, : x.shape[1], None, :] + x[ :, :, :, self.mask1, ] * self.mask2 * self.r2[None, : x.shape[1], None, :] ).transpose(1, 2) def extra_repr(self) -> str: return f"dim_per_head={self.dim_per_head}, max_seq_len={self.max_seq_len}" class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-9): super().__init__() self.dim = dim self.trainable = nn.Parameter( data=torch.nn.init.normal_(torch.zeros((dim,))), requires_grad=True ) self.eps = eps def forward(self, x: Tensor): """ Args: x (Tensor): input tensor. shape: (bs, seq_len, embed_dim) Returns: Tensor: input tensor with rotary embeddings. shape: (bs, seq_len, embed_dim) """ assert x.ndim == 3, "input must have 3 dimensions: (bs, seq_len, embed_dim)" return ( x / torch.sqrt_(torch.mean(torch.square(x), dim=-1) + self.eps)[:, :, None] * self.trainable ) def extra_repr(self) -> str: return f"dim={self.dim}, eps={self.eps}" class SiLU(nn.Module): def __init__(self): super().__init__() def forward(self, x: Tensor): """ Args: x (Tensor): input """ return x * x.sigmoid() class SwiGLU(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.linear_inp1 = nn.Linear(dim, (8 * dim) // 3, bias=False) self.linear_inp2 = nn.Linear(dim, (8 * dim) // 3, bias=False) self.linear_out = nn.Linear((8 * dim) // 3, dim, bias=False) self.silu = SiLU() # nn.init.xavier_uniform_(self.linear_inp1.weight) # nn.init.xavier_uniform_(self.linear_inp2.weight) # nn.init.xavier_uniform_(self.linear_out.weight) def forward(self, x: Tensor): """ Args: x (Tensor): input tensor """ return self.linear_out(self.silu(self.linear_inp1(x)) * self.linear_inp2(x)) class MistralTokenizer(nn.Module): def __init__(self, max_length=1024, *args, **kwargs): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained( "mistralai/Mistral-7B-v0.1", *args, **kwargs ) self.tokenizer.add_special_tokens({"pad_token": ""}) self.special_tokens_ids = { token: id for token, id in zip( self.tokenizer.special_tokens_map.keys(), self.tokenizer.all_special_ids ) } self.max_length = max_length self.pad_token_id = self.tokenizer.pad_token_id def forward(self, text): return self.tokenizer( text, return_tensors="pt", return_attention_mask=False, max_length=self.max_length, truncation=True, padding=True, padding_side="right", ) def convert_ids_to_tokens(self, ids): return self.tokenizer.convert_ids_to_tokens(ids) def decode(self, x): return self.tokenizer.batch_decode(x) def __len__(self): return len(self.tokenizer) class MultiHeadAttention(nn.Module): def __init__( self, emb_size: int, n_heads: int, dropout: float = 0.0, use_rotary_embeddings: bool = False, bias_qkv: bool = False, bias_out: bool = False, ): super().__init__() self.emb_size = emb_size self.n_heads = n_heads assert ( self.emb_size % n_heads == 0 ), "Embedding size needs to be divisible by heads" self.head_dim = emb_size // n_heads self.use_rotary_embeddings = use_rotary_embeddings if self.use_rotary_embeddings: self.rotary_embed = RotaryEmbedding(self.head_dim) self.qkv = nn.Linear(emb_size, emb_size * 3, bias=bias_qkv) self.dropout = nn.Dropout(dropout) self.out = nn.Linear(emb_size, emb_size, bias=bias_out) self.scaling = self.head_dim**-0.5 def forward(self, x: Tensor, att_mask: Tensor = None): qkv = self.qkv(x).chunk(3, dim=-1) q, k, v = map( lambda t: t.reshape(x.shape[0], -1, self.n_heads, self.head_dim).transpose( 1, 2 ), qkv, ) # [batch_size, n_heads, seq_len, head_dim] if self.use_rotary_embeddings: q, k = self.rotary_embed(q), self.rotary_embed(k) dots = ( torch.matmul(q, k.transpose(-1, -2)) * self.scaling ) # [batch_size, n_heads, seq_len, seq_len] if att_mask is not None: dots = dots + att_mask attn = self.dropout(torch.softmax(dots, dim=-1)) out = ( torch.matmul(attn, v).transpose(1, 2).reshape(x.shape[0], -1, self.emb_size) ) out = self.out(out) return out class LLaMADecoderLayer(nn.Module): def __init__( self, emb_size: int, n_heads: int, dropout: float, ) -> None: super().__init__() self.emb_size = emb_size self.multihead_attn = MultiHeadDispatch( dim_model=emb_size, num_heads=n_heads, attention=ScaledDotProduct( dropout=dropout, ), bias=(False, False, False, False), use_rotary_embeddings=True, ) self.rmsnorm1 = nn.RMSNorm(emb_size, eps=1e-9) self.rmsnorm2 = nn.RMSNorm(emb_size, eps=1e-9) self.swiglu = SwiGLU(emb_size) self.n_heads = n_heads def forward(self, in_tuple) -> Tensor: """ Args: in_tuple (tuple[Tensor, Tensor, Tensor]): tuple, containing 3 tensors: x (Tensor): input tensor (bs, seq_len, dim) attn_mask (Tensor): attention mask (seq_len, seq_len) padding_mask (Tensor): padding mask (bs, seq_len) Returns: Tensor: output tensor """ assert len(in_tuple) == 2, "input tuple must have 2 elements" x, mask = in_tuple x = self.multihead_attn(self.rmsnorm1(x), att_mask=mask) + x return self.swiglu(self.rmsnorm2(x)) + x, mask class CustomAttentionLLaMaDecoder(LLaMADecoderLayer): def __init__( self, emb_size: int, n_heads: int, dropout: float, ) -> None: super().__init__(emb_size, n_heads, dropout) self.multihead_attn = MultiHeadAttention( emb_size=emb_size, n_heads=n_heads, bias_qkv=False, bias_out=False, use_rotary_embeddings=True, dropout=dropout, ) self.rmsnorm1 = RMSNorm(emb_size, eps=1e-9) self.rmsnorm2 = RMSNorm(emb_size, eps=1e-9) class LLaMaBase(nn.Module): def __init__( self, embed_dim: int = 512, n_layers: int = 2, n_heads: int = 8, dropout: int = 0.0, n_chckpnt_segments: int = 1, tokenizer=MistralTokenizer(), **kwargs, ): """ Args: n_feats (int): number of input features. n_class (int): number of classes. fc_hidden (int): number of hidden features. """ super().__init__() self.tokenizer = tokenizer self.vocab_len = len(tokenizer) self.n_heads = n_heads self.dropout = dropout self.n_layers = n_layers self.embed_dim = embed_dim self.n_segments = n_chckpnt_segments self.embed = nn.Embedding( self.vocab_len, embed_dim, padding_idx=self.tokenizer.pad_token_id ) self.head = nn.Linear(embed_dim, self.vocab_len, bias=False) def forward(self, src: Tensor, attn_mask: Tensor, pad_mask: Tensor, **batch): """ Model forward method. Args: tokenized (Tensor): input text. shape: (batch_size, seq_len) Returns: output (dict): output dict containing logits. """ raise NotImplementedError def __str__(self): """ Model prints with the number of parameters. """ all_parameters = sum([p.numel() for p in self.parameters()]) trainable_parameters = sum( [p.numel() for p in self.parameters() if p.requires_grad] ) embedding_parameters = sum([p.numel() for p in self.embed.parameters()]) result_info = super().__str__() result_info = result_info + f"\nAll parameters: {all_parameters}" result_info = result_info + f"\nTrainable parameters: {trainable_parameters}" result_info = ( result_info + f"\nWithout embedding: {trainable_parameters - embedding_parameters}" ) return result_info class CustomAttentionLLaMa(LLaMaBase): def __init__( self, embed_dim: int = 512, n_layers: int = 2, n_heads: int = 8, dropout: int = 0.0, n_chckpnt_segments: int = 1, tokenizer=MistralTokenizer(), **kwargs, ): """ Args: n_feats (int): number of input features. n_class (int): number of classes. fc_hidden (int): number of hidden features. """ super().__init__( embed_dim, n_layers, n_heads, dropout, n_chckpnt_segments, tokenizer, ) self.decoders = nn.Sequential( *[ CustomAttentionLLaMaDecoder( emb_size=embed_dim, n_heads=self.n_heads, dropout=dropout ) for _ in range(n_layers) ] ) self.rmsnorm = RMSNorm(embed_dim, eps=1e-9) def forward(self, src: Tensor, attn_mask: Tensor, pad_mask: Tensor, **batch): """ Model forward method. Args: tokenized (Tensor): input text. shape: (batch_size, seq_len) Returns: output (dict): output dict containing logits. """ x = self.embed(src) # embeds shape: [batch_size, seq_len, embed_dim] sizes = x.shape mask = maybe_merge_masks( attn_mask, pad_mask, sizes[0], sizes[1], self.n_heads ).view(x.shape[0], self.n_heads, sizes[1], sizes[1]) x, _ = checkpoint_sequential(self.decoders, self.n_segments, input=(x, mask)) # for decoder in self.decoders: # x, _, _ = decoder((x, attn_mask, pad_mask)) logits = self.head(self.rmsnorm(x)) return { "logits": logits.permute(0, 2, 1) } # logits shape: [batch_size, vocab_len, seq_len]