import torch
from torch import nn
from typing import Optional
from dataclasses import dataclass
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import ModelOutput
from .attention import MultiHeadAttention, MultiHeadPAttention, PAttention, LayerNorm
from .mlp import swiglu_ln_ffn, intermediate_correction_fn


class TransformerBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        n_heads: int,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.1,
        rotary: bool = False,
        use_bias: bool = False,
    ):
        super().__init__()
        # dropout and use_bias are only applied in the feed-forward network;
        # the attention module uses its own defaults.
        self.attn = MultiHeadAttention(hidden_size, n_heads, rotary)
        self.ffn = swiglu_ln_ffn(hidden_size, expansion_ratio, dropout, use_bias)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Residual connections around the attention and feed-forward sublayers.
        x = self.attn(x, attention_mask) + x
        x = self.ffn(x) + x
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        n_heads: int,
        n_layers: int,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.1,
        rotary: bool = False,
        use_bias: bool = False,
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, n_heads, expansion_ratio, dropout, rotary, use_bias)
            for _ in range(n_layers)
        ])

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        if attention_mask is not None and attention_mask.ndim == 2:
            # Broadcast a 2D padding mask of shape (batch, seq_len) to a boolean
            # (batch, 1, seq_len, seq_len) mask shared across heads and query positions.
            attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x
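
# A minimal usage sketch for the backbone alone (sizes here are arbitrary):
#
#     backbone = Transformer(hidden_size=64, n_heads=4, n_layers=2)
#     x = torch.randn(2, 16, 64)                  # already-embedded inputs
#     mask = torch.ones(2, 16, dtype=torch.long)  # 2D padding mask
#     h = backbone(x, mask)                       # -> (2, 16, 64)

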
class TokenFormerBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        n_heads: int,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.1,
        rotary: bool = False,
    ):
        super().__init__()
        # Pre-norm TokenFormer block: attention over learned parameter tokens,
        # with a PAttention layer standing in for the usual MLP.
        self.ln1 = LayerNorm(hidden_size)
        self.attn = MultiHeadPAttention(
            hidden_size=hidden_size,
            n_heads=n_heads,
            n_tokens=hidden_size,
            dropout=dropout,
            rotary=rotary,
        )
        self.ln2 = LayerNorm(hidden_size)
        self.ffn = PAttention(
            hidden_size=hidden_size,
            n_tokens=intermediate_correction_fn(expansion_ratio, hidden_size),
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.attn(self.ln1(x), attention_mask) + x
        x = self.ffn(self.ln2(x)) + x
        return x


class TokenFormer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        n_heads: int,
        n_layers: int,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.1,
        rotary: bool = False,
        use_bias: bool = False,
    ):
        super().__init__()
        # use_bias is accepted for signature parity with Transformer but is not
        # used by TokenFormerBlock.
        self.layers = nn.ModuleList([
            TokenFormerBlock(hidden_size, n_heads, expansion_ratio, dropout, rotary)
            for _ in range(n_layers)
        ])

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x
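
# TokenFormer exposes the same interface as Transformer, so it can be swapped
# in as the backbone without other changes (a sketch; sizes are arbitrary):
#
#     backbone = TokenFormer(hidden_size=64, n_heads=4, n_layers=2)
#     h = backbone(torch.randn(2, 16, 64))        # -> (2, 16, 64)

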
class TransformerConfig(PretrainedConfig):
    model_type = "transformer"

    def __init__(
        self,
        hidden_size: int = 512,
        n_heads: int = 8,
        n_layers: int = 12,
        vocab_size: int = 32000,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.1,
        rotary: bool = True,
        attn_implementation: str = 'sdpa',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.dropout = dropout
        self.rotary = rotary
        self.vocab_size = vocab_size
        self.attn_implementation = attn_implementation


@dataclass
class TransformerOutput(ModelOutput):
    """Output type for Transformer masked-LM models."""
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None


class TransformerForMaskedLM(PreTrainedModel):
    config_class = TransformerConfig
    all_tied_weights_keys = {}

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.transformer = Transformer(
            hidden_size=config.hidden_size,
            n_heads=config.n_heads,
            n_layers=config.n_layers,
            expansion_ratio=config.expansion_ratio,
            dropout=config.dropout,
            rotary=config.rotary,
        )
        self.lm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.vocab_size),
        )
        self.ce_loss = nn.CrossEntropyLoss()
        self.vocab_size = config.vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_preds: bool = True,
    ) -> TransformerOutput:
        x = self.embeddings(input_ids)
        x = self.transformer(x, attention_mask)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
        return TransformerOutput(
            loss=loss,
            # With return_preds=True the logits field holds argmax token
            # predictions rather than raw logits.
            logits=logits.argmax(dim=-1) if return_preds else logits,
            last_hidden_state=x,
        )
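
if __name__ == "__main__":
    # Minimal smoke test / usage sketch: the sizes below are arbitrary, and the
    # file must be run as a module inside its package so the relative
    # .attention / .mlp imports resolve.
    config = TransformerConfig(hidden_size=64, n_heads=4, n_layers=2, vocab_size=100)
    model = TransformerForMaskedLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    out = model(input_ids, attention_mask=attention_mask, labels=input_ids.clone(), return_preds=False)
    print(out.loss, out.logits.shape, out.last_hidden_state.shape)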