| """ |
| HuggingFace PreTrainedModel wrapper for InterpGPT / TaskGPT. |
| |
| Weights map 1:1 to the original gpt_model.TaskGPT state dict, so the same |
| .pt checkpoints produced during Phase 1 load here without remapping. |
| |
| Usage (after upload): |
| from transformers import AutoModel, AutoTokenizer |
| model = AutoModel.from_pretrained("connaaa/interpgpt-standard-23M", |
| trust_remote_code=True) |
| # Or for the analysis pipeline: |
| from transformer_lens import HookedTransformer |
| hooked = HookedTransformer.from_pretrained("connaaa/interpgpt-standard-23M", |
| hf_model=model, |
| ...) |
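
    # A Phase 1 .pt checkpoint can also be loaded directly. This is a sketch:
    # the file name and the assumption that the checkpoint is a bare state
    # dict (not nested under another key) are illustrative, not guaranteed.
    state = torch.load("taskgpt_phase1.pt", map_location="cpu")
    model.load_state_dict(state)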
| """ |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_interpgpt import InterpGPTConfig


class RMSNorm(nn.Module):
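    """Root-mean-square layer norm: rescale by 1 / rms(x) with a learned gain; no mean-centering, no bias."""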
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight


class RotaryPositionalEncoding(nn.Module):
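    """Precomputes cos/sin tables of shape (max_seq_len, d_model // 2) for rotary position embeddings (RoPE)."""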
    def __init__(self, d_model: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        assert d_model % 2 == 0
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer("inv_freq", inv_freq)
        t = torch.arange(max_seq_len, dtype=torch.float)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, seq_len: int):
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def apply_rotary_emb(x, cos, sin):
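    """Rotates the two halves of the last dimension by position-dependent angles (rotate-half RoPE convention)."""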
    d_half = x.shape[-1] // 2
    x1, x2 = x[..., :d_half], x[..., d_half:]
    cos = cos[: x.shape[2]].unsqueeze(0).unsqueeze(0)
    sin = sin[: x.shape[2]].unsqueeze(0).unsqueeze(0)
    out1 = x1 * cos - x2 * sin
    out2 = x2 * cos + x1 * sin
    return torch.cat([out1, out2], dim=-1)


class CausalSelfAttention(nn.Module):
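    """Multi-head causal self-attention: fused QKV projection, rotary embeddings per head, optional KV cache."""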
    def __init__(self, config: InterpGPTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.rope = RotaryPositionalEncoding(self.head_dim, config.max_seq_len)
        mask = torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
        self.register_buffer("causal_mask", mask.view(1, 1, config.max_seq_len, config.max_seq_len))

    def forward(self, x, kv_cache=None):
        B, T, D = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Offset rotary positions by the cached length so incrementally decoded
        # tokens are rotated by their absolute position, not position 0.
        past_len = kv_cache["k"].size(2) if (kv_cache is not None and "k" in kv_cache) else 0
        cos, sin = self.rope(past_len + T)
        q = apply_rotary_emb(q, cos[past_len:], sin[past_len:])
        k = apply_rotary_emb(k, cos[past_len:], sin[past_len:])
        if kv_cache is not None:
            if "k" in kv_cache:
                k = torch.cat([kv_cache["k"], k], dim=2)
                v = torch.cat([kv_cache["v"], v], dim=2)
            kv_cache["k"] = k
            kv_cache["v"] = v
        if hasattr(F, "scaled_dot_product_attention") and kv_cache is None:
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=True,
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = torch.matmul(q, k.transpose(-2, -1)) * scale
            T_k = k.size(2)
            causal = self.causal_mask[:, :, T_k - T : T_k, :T_k]
            attn = attn.masked_fill(causal == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.resid_dropout(self.out_proj(out))


class FeedForward(nn.Module):
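    """SwiGLU feed-forward block; hidden width is 2/3 * d_ff rounded up to a multiple of 64 (LLaMA-style sizing)."""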
    def __init__(self, config: InterpGPTConfig):
        super().__init__()
        hidden = int(2 * config.d_ff / 3)
        hidden = 64 * ((hidden + 63) // 64)
        self.gate_proj = nn.Linear(config.d_model, hidden, bias=config.bias)
        self.up_proj = nn.Linear(config.d_model, hidden, bias=config.bias)
        self.down_proj = nn.Linear(hidden, config.d_model, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))


class TransformerBlock(nn.Module):
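    """Pre-norm residual block: x + attn(RMSNorm(x)), then x + ffn(RMSNorm(x))."""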
    def __init__(self, config: InterpGPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.ffn = FeedForward(config)

    def forward(self, x, kv_cache=None):
        x = x + self.attn(self.ln1(x), kv_cache)
        x = x + self.ffn(self.ln2(x))
        return x


class InterpGPTModel(PreTrainedModel):
    """
    HF-wrapped InterpGPT / TaskGPT. State dict parameter names match the
    original gpt_model.TaskGPT exactly, so Phase 1 .pt checkpoints load
    via load_state_dict without remapping.
    """
    config_class = InterpGPTConfig
    base_model_prefix = "interpgpt"
    supports_gradient_checkpointing = False

    def __init__(self, config: InterpGPTConfig):
        super().__init__(config)
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_final = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
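        # Tie the output head to the token embedding matrix (weight sharing).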
        self.lm_head.weight = self.token_embedding.weight
        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                nn.init.zeros_(module.weight[module.padding_idx])

    def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None, **kwargs):
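        # attention_mask is accepted for HF API compatibility but not used;
        # attention is always dense causal over the full input sequence.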
        B, T = input_ids.shape
        x = self.drop(self.token_embedding(input_ids))
        for block in self.blocks:
            x = block(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        output = {"logits": logits}
        if labels is not None:
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=self.config.pad_id,
                reduction="none",
            ).view(B, T - 1)
            if loss_mask is not None:
                shift_mask = loss_mask[:, 1:].contiguous().float()
                loss = (loss * shift_mask).sum() / shift_mask.sum().clamp(min=1.0)
            else:
                # Average over non-padding targets only, so padded positions
                # (zero loss under ignore_index) do not dilute the mean.
                valid = (shift_labels != self.config.pad_id).float()
                loss = (loss * valid).sum() / valid.sum().clamp(min=1.0)
            output["loss"] = loss
        return output