"""GPT-style causal language model built on the Hugging Face ``PreTrainedModel`` API."""

import math
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.cache_utils import DynamicCache
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .config import GPTConfig

# Substrings that mark a parameter (by its qualified name) as a "control"
# tensor; ``restore_fp32_params`` keeps every matching parameter in FP32.
CONTROL_TENSOR_NAME_PATTERNS = (
    "scale",
    "gate",
    "gain",
    "norm",
    "ln_",
    "rms",
)


class CastedLinear(nn.Linear):
    """Store linear params in FP32, cast to activation dtype for matmul."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear map with weight/bias cast to ``x.dtype``."""
        weight = self.weight.to(dtype=x.dtype)
        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)


def restore_fp32_params(model: nn.Module) -> None:
    """Keep linear weights and control params in FP32 after dtype conversion.

    Every ``CastedLinear`` submodule is converted with ``.float()``; in
    addition, any parameter that has fewer than two dimensions or whose name
    contains one of ``CONTROL_TENSOR_NAME_PATTERNS`` is cast back to FP32.
    """
    for module in model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    for name, param in model.named_parameters():
        is_control = param.ndim < 2 or any(
            pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS
        )
        if is_control and param.dtype != torch.float32:
            param.data = param.data.float()


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` in FP32, then return the scaled result in ``x.dtype``."""
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).to(dtype=x.dtype) * self.weight.to(dtype=x.dtype)


def build_rope_inv_freq(head_dim, theta=2500.0):
    """Return the ``head_dim // 2`` inverse rotary frequencies ``theta ** (-2i / head_dim)``."""
    return 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))


def precompute_rope_cos_sin(head_dim, seq_len, theta=2500.0):
    """Return ``(cos, sin)`` tables of shape ``(seq_len, head_dim // 2)`` for positions ``0..seq_len-1``."""
    freqs = build_rope_inv_freq(head_dim, theta)
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return freqs.cos(), freqs.sin()


def _apply_rope(x, cos, sin):
    """Rotate adjacent (even, odd) feature pairs of ``x`` by the given angles.

    The computation runs in FP32 and the result is cast back to ``x``'s dtype.
    ``cos``/``sin`` get two leading singleton axes so that they broadcast
    against the paired view of ``x`` (callers in this file pass ``x`` as
    ``(B, H, T, D)`` and tables as ``(T, D // 2)``).
    """
    x_float = x.float()
    x_pair = x_float.reshape(*x_float.shape[:-1], -1, 2)
    even = x_pair[..., 0]
    odd = x_pair[..., 1]
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    x_rot = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
    return x_rot.flatten(-2).type_as(x)


def apply_rotary_emb(q, k, freqs_cis):
    """Apply the same rotary embedding to ``q`` and ``k``; ``freqs_cis`` is a ``(cos, sin)`` pair."""
    cos, sin = freqs_cis
    return _apply_rope(q, cos, sin), _apply_rope(k, cos, sin)


class GPTAttention(nn.Module):
    """Multi-head self-attention with grouped KV heads, RoPE and optional XSA projection."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.n_rep = self.n_head // self.n_kv_heads
        self.xsa_projection = config.xsa_projection
        self.q_proj = CastedLinear(config.hidden_size, self.n_head * self.head_dim, bias=False)
        self.k_proj = CastedLinear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = CastedLinear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.n_head * self.head_dim, config.hidden_size, bias=False)

    def _xsa_efficient(self, y: Tensor, v_current: Tensor) -> Tensor:
        """Remove from ``y`` its component along the current token's value direction.

        ``y`` is ``(B, H, T, D)`` and ``v_current`` is ``(B, Hkv, T, D)``. The
        query heads are grouped per KV head, and each output vector has its
        projection onto the L2-normalised value vector of its group subtracted.
        """
        B, H, T, D = y.shape
        Hkv = v_current.size(1)
        group = H // Hkv
        y_g = y.reshape(B, Hkv, group, T, D)
        v_n = F.normalize(v_current, dim=-1).unsqueeze(2)
        proj = (y_g * v_n).sum(dim=-1, keepdim=True) * v_n
        return (y_g - proj).reshape(B, H, T, D)

    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
        """Run self-attention over ``x`` of shape ``(B, T, hidden)``.

        Args:
            x: Input activations, ``(B, T, hidden)``.
            freqs_cis: ``(cos, sin)`` rotary tables for the ``T`` current positions.
            past_key_value: Optional cache exposing ``update(k, v, layer_idx)``;
                when given, the current keys/values are appended to it and
                attention runs over the returned full key/value tensors.
            use_cache: Accepted for interface compatibility; not read here.
            attention_mask: Optional 2-D mask whose last axis indexes key
                positions; non-zero entries mark keys that may be attended to.

        Returns:
            Tensor of shape ``(B, T, hidden)``.
        """
        B, T, _ = x.size()
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k_current = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v_current = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q, k_current = apply_rotary_emb(q, k_current, freqs_cis)

        if past_key_value is not None:
            k, v = past_key_value.update(k_current, v_current, self.layer_idx)
        else:
            k, v = k_current, v_current
        S = k.size(2)

        key_pad = None
        if attention_mask is not None:
            key_pad = attention_mask.to(torch.bool)[:, None, None, :]

        attn_mask = None
        is_causal = False
        if key_pad is None and S == T:
            # No padding and no cached prefix: let SDPA apply the causal mask itself.
            is_causal = True
        elif T > 1:
            # Several queries at once need an explicit causal mask. The diagonal
            # offset S - T aligns query i with absolute key position (S - T) + i,
            # so this also stays causal when a cached prefix is present (S > T).
            causal = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
            causal = causal[None, None, :, :]
            attn_mask = causal if key_pad is None else key_pad & causal
        elif key_pad is not None:
            # Single query: every key is in its past, only padding must be masked.
            attn_mask = key_pad.expand(B, 1, T, S)

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            is_causal=is_causal,
            enable_gqa=(self.n_kv_heads != self.n_head),
        )
        if self.xsa_projection:
            y = self._xsa_efficient(y, v_current)
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        return self.o_proj(y)


class GPTMLP(nn.Module):
    """Gated feed-forward block: ``w_down(silu(w_gate(x)) * w_up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.w_gate = CastedLinear(config.hidden_size, config.intermediate_size, bias=False)
        self.w_up = CastedLinear(config.hidden_size, config.intermediate_size, bias=False)
        self.w_down = CastedLinear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


class GPTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = GPTAttention(config, layer_idx)
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = GPTMLP(config)

    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
        x = x + self.attn(self.ln_1(x), freqs_cis, past_key_value, use_cache, attention_mask=attention_mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class GPTPreTrainedModel(PreTrainedModel):
    """Base class wiring ``GPTConfig`` and weight initialisation into ``PreTrainedModel``."""

    config_class = GPTConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise linear and embedding weights from ``N(0, hidden_size ** -0.5)``."""
        std = self.config.hidden_size ** -0.5
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)


class GPTForCausalLM(GPTPreTrainedModel, GenerationMixin):
    """Decoder-only language model with an LM head and optional KV caching."""

    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.hidden_size),
            h=nn.ModuleList([GPTBlock(config, i) for i in range(config.num_hidden_layers)]),
            ln_f=RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
        ))
        self.lm_head = CastedLinear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.transformer["wte"].weight
        # Lazily built (cos, sin) rotary tables; see ``_get_freqs_cis``.
        self._freqs_cis_cache = None
        self.post_init()
        restore_fp32_params(self)

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` as ``nn.Module._apply`` does, then re-assert the FP32 parameters.

        Extra arguments are forwarded unchanged because some ``nn.Module``
        entry points (e.g. ``to_empty``) call ``_apply(fn, recurse=...)``.
        """
        module = super()._apply(fn, *args, **kwargs)
        restore_fp32_params(self)
        return module

    def get_input_embeddings(self):
        return self.transformer["wte"]

    def set_input_embeddings(self, value):
        self.transformer["wte"] = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Build the forward kwargs for one generation step.

        Once the cache holds at least one position, only the last token of
        ``input_ids`` is fed to the model. ``use_cache`` is always enabled.
        """
        if past_key_values is not None and past_key_values.get_seq_length() > 0:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
        }

    def _get_freqs_cis(self, seq_len, device):
        """Return ``(cos, sin)`` rotary tables for positions ``0..seq_len-1`` on ``device``.

        The tables are rebuilt when no cached copy exists, the cached copy is
        on another device, or it is shorter than ``seq_len``.
        """
        cache = self._freqs_cis_cache
        if cache is None or cache[0].device != device or cache[0].size(0) < seq_len:
            cache = tuple(
                tensor.to(device)
                for tensor in precompute_rope_cos_sin(self.config.head_dim, seq_len, self.config.rope_theta)
            )
            # Tables built under inference mode are not stored on the module.
            # NOTE(review): presumably so that inference tensors are never
            # reused later outside inference mode — confirm.
            if not torch.is_inference_mode_enabled():
                self._freqs_cis_cache = cache
        return cache[0][:seq_len], cache[1][:seq_len]

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        past_key_values: Optional[DynamicCache] = None,
        use_cache=False,
        **kwargs,
    ):
        """Compute logits (and optionally the LM loss) for ``input_ids`` of shape ``(B, T)``.

        Args:
            input_ids: Token ids, ``(B, T)``.
            attention_mask: Optional key-padding mask forwarded to every block.
            labels: Optional targets. They are shifted by one position here
                unless ``config.labels_are_shifted`` is truthy.
            past_key_values: Optional KV cache; a fresh ``DynamicCache`` is
                created when ``use_cache`` is set and none is given.
            use_cache: Whether the blocks read/update the cache and whether it
                is returned.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss``, ``logits`` and, when
            ``use_cache`` is set, ``past_key_values``.
        """
        B, T = input_ids.size()
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0

        x = self.transformer["wte"](input_ids)

        # Slice each table along the position axis so the current tokens get
        # the angles for absolute positions past_len .. past_len + T - 1.
        # (Slicing the returned (cos, sin) tuple itself would drop a table.)
        cos, sin = self._get_freqs_cis(past_len + T, input_ids.device)
        freqs_cis = (cos[past_len:], sin[past_len:])

        layer_cache = past_key_values if use_cache else None
        for block in self.transformer["h"]:
            x = block(x, freqs_cis, layer_cache, use_cache, attention_mask=attention_mask)
        x = self.transformer["ln_f"](x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            if getattr(self.config, "labels_are_shifted", False):
                loss = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), labels.reshape(-1))
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss = F.cross_entropy(shift_logits.float().view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=layer_cache,
        )