| | import math |
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutput |
| |
|
| | from .configuration_binaryllm import BinaryLLMConfig |
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding.

    The table is precomputed once in float32 and registered as a
    non-persistent buffer (it is not written to the state dict). On every
    forward pass, the slice covering the current sequence length is cast to
    the input's device and dtype, then added to the input.

    Args:
        d_model: Embedding width. Odd widths are supported: the last,
            unpaired column receives only a sine term.
        max_len: Maximum sequence length the table covers.
    """

    def __init__(self, d_model: int, max_len: int) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has ceil(d/2) sine columns but only floor(d/2)
        # cosine columns. Trim div_term so the shapes match; for even
        # d_model this slice is the whole tensor and nothing changes.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading singleton dim so the table broadcasts over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to ``x``.

        Args:
            x: Tensor expected to be laid out as (batch, seq, d_model).

        Returns:
            ``x`` plus the first ``seq`` rows of the table, cast to ``x``'s
            device and dtype.

        Raises:
            ValueError: If ``seq`` exceeds the ``max_len`` the table was
                built for.
        """
        t = x.size(1)
        max_len = self.pe.size(1)
        if t > max_len:
            # Without this check, the slice below silently stops at max_len
            # and the addition fails with an opaque broadcasting error.
            raise ValueError(
                f"Sequence length {t} exceeds the positional-encoding maximum of {max_len}."
            )
        pe = self.pe[:, :t, :].to(device=x.device, dtype=x.dtype)
        return x + pe
| |
|
| |
|
@dataclass
class _InnerCfg:
    """Hyper-parameter bundle consumed by ``TinyTransformerLM``.

    Field order defines the positional constructor signature, so keep it
    stable.
    """

    block_size: int  # max sequence length: sizes the positional table and the causal mask
    embed_dim: int  # token-embedding width and encoder d_model
    vocab_size: int  # number of embedding rows and output logits
    num_heads: int  # attention heads per encoder layer
    num_layers: int  # number of stacked encoder layers
    ff_hidden_dim: int  # feed-forward inner width of each encoder layer
    dropout: float  # dropout probability passed to each encoder layer
    layernorm_dim: Optional[int] = None  # final LayerNorm width; falsy -> embed_dim
    head_dim: Optional[int] = None  # LM-head input width; falsy -> LayerNorm width
| |
|
| |
|
class TinyTransformerLM(nn.Module):
    """
    Causal language model built from ``nn.TransformerEncoder`` layers plus a
    causal attention mask.

    Pipeline: token embedding -> sinusoidal positions -> encoder stack ->
    optional projection to ``layernorm_dim`` -> LayerNorm -> optional
    projection to ``head_dim`` -> bias-free vocabulary head.

    When neither optional projection is created, the output head shares its
    weight matrix with the token embedding (weight tying).
    """

    def __init__(self, cfg: _InnerCfg) -> None:
        super().__init__()
        self.cfg = cfg

        vocab_size = cfg.vocab_size
        self.tok_embed = nn.Embedding(vocab_size, cfg.embed_dim)
        self.pos_encoding = PositionalEncoding(cfg.embed_dim, cfg.block_size)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.embed_dim,
            nhead=cfg.num_heads,
            dim_feedforward=cfg.ff_hidden_dim,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.num_layers)

        # A falsy (None/0) width falls back to the previous stage's width.
        ln_dim = cfg.layernorm_dim or cfg.embed_dim
        head_dim = cfg.head_dim or ln_dim

        self.pre_ln_proj: Optional[nn.Linear] = None
        if ln_dim != cfg.embed_dim:
            self.pre_ln_proj = nn.Linear(cfg.embed_dim, ln_dim)

        self.ln = nn.LayerNorm(ln_dim)

        self.head_pre: Optional[nn.Linear] = None
        if head_dim != ln_dim:
            self.head_pre = nn.Linear(ln_dim, head_dim)

        self.head = nn.Linear(head_dim, vocab_size, bias=False)

        # Tie only when the head consumes embed_dim-wide features, so the
        # (vocab_size, embed_dim) embedding matrix fits the head exactly.
        if self.pre_ln_proj is None and self.head_pre is None and head_dim == cfg.embed_dim:
            self.head.weight = self.tok_embed.weight

        # Bool mask, True strictly above the diagonal: query i may not attend
        # to key j > i. Non-persistent, so it is rebuilt rather than loaded.
        causal = torch.triu(torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", causal, persistent=False)

    def forward(self, tokens: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute next-token logits.

        Args:
            tokens: Integer token ids laid out as (batch, seq).
            padding_mask: Optional mask. It is cast to bool and passed to the
                encoder as ``src_key_padding_mask`` (PyTorch convention: True
                marks key positions to ignore). Columns beyond ``seq`` are
                dropped.

        Returns:
            Logits of shape (batch, seq, vocab_size).

        Raises:
            ValueError: If ``seq`` exceeds ``block_size``.
        """
        seq_len = tokens.size(1)
        max_len = self.causal_mask.size(0)
        if seq_len > max_len:
            # Fail early with a clear message. Otherwise the causal-mask
            # slice silently stays at block_size and the positional encoding
            # fails with an opaque shape error.
            raise ValueError(
                f"Input sequence length {seq_len} exceeds the model's block_size of {max_len}."
            )

        x = self.tok_embed(tokens)
        x = self.pos_encoding(x)

        attn_mask = self.causal_mask[:seq_len, :seq_len].to(device=tokens.device)

        if padding_mask is not None:
            # NOTE(review): with left padding, a query whose every visible
            # (causal) key is padding, e.g. position 0, gets a fully masked
            # attention row. Depending on the attention backend this may
            # produce NaN. Prefer right padding or verify on your torch version.
            padding_mask = padding_mask[:, :seq_len].to(device=tokens.device, dtype=torch.bool)

        x = self.encoder(x, mask=attn_mask, src_key_padding_mask=padding_mask)

        if self.pre_ln_proj is not None:
            x = self.pre_ln_proj(x)

        x = self.ln(x)

        if self.head_pre is not None:
            x = self.head_pre(x)

        return self.head(x)
| |
|
| |
|
class BinaryLLMForCausalLM(PreTrainedModel):
    """
    Hugging Face wrapper exposing ``TinyTransformerLM`` as a causal LM.

    The inner model is always built with ``layernorm_dim=None`` and
    ``head_dim=None``, so its output head shares its weight with the token
    embedding. ``_tied_weights_keys`` declares that duplicate. Without it,
    ``save_pretrained`` with safetensors refuses to save the shared tensors;
    with it, the duplicate key is dropped on save. The embedding accessors
    let Hugging Face utilities (e.g. ``resize_token_embeddings``,
    ``tie_weights``) locate the two matrices.
    """

    config_class = BinaryLLMConfig
    main_input_name = "input_ids"
    # The head weight is the same Parameter as model.tok_embed.weight.
    _tied_weights_keys = ["model.head.weight"]

    def __init__(self, config: BinaryLLMConfig):
        super().__init__(config)

        inner = _InnerCfg(
            block_size=int(config.max_position_embeddings),
            embed_dim=int(config.hidden_size),
            vocab_size=int(config.vocab_size),
            num_heads=int(config.num_attention_heads),
            num_layers=int(config.num_hidden_layers),
            ff_hidden_dim=int(config.intermediate_size),
            dropout=float(getattr(config, "dropout", 0.0)),
            layernorm_dim=None,
            head_dim=None,
        )
        self.model = TinyTransformerLM(inner)

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token-embedding module."""
        return self.model.tok_embed

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token-embedding module."""
        self.model.tok_embed = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the vocabulary projection (LM head)."""
        return self.model.head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        """Replace the vocabulary projection (LM head)."""
        self.model.head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        """
        Run the language model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids laid out as (batch, seq).
            attention_mask: Optional mask where nonzero/True marks real
                tokens. It is inverted into the inner model's padding mask.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted internally so position t predicts token t+1; entries
                equal to -100 are ignored.
            **kwargs: Accepted for Hugging Face call compatibility and ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` and, when ``labels`` is given,
            ``loss``.
        """
        padding_mask = None
        if attention_mask is not None:
            padding_mask = ~attention_mask.to(torch.bool)

        logits = self.model(input_ids, padding_mask=padding_mask)

        loss = None
        if labels is not None:
            # Upcast so the softmax/log-sum-exp is not computed in fp16/bf16.
            shift_logits = logits[:, :-1, :].float()
            # Labels may live on another device (e.g. model parallelism).
            shift_labels = labels[:, 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutput(loss=loss, logits=logits)
| |
|