# src/model/testformer.py
from __future__ import annotations

import torch
import torch.nn as nn

from src.model.testformer_block import TestFormerBlock
from src.model.testformer_config import TestFormerConfig
from src.model.testformer_loss import testformer_lm_loss
class TestFormerLM(nn.Module):
    """Decoder-only language model built from stacked TestFormerBlock layers."""

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"device": device, "dtype": dtype}
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.d_model, **factory_kwargs)
        self.embedding_dropout = nn.Dropout(cfg.emb_dropout)
        self.blocks = nn.ModuleList(
            [
                TestFormerBlock(cfg, layer_idx=layer_idx, device=device, dtype=dtype)
                for layer_idx in range(cfg.n_layers)
            ]
        )
        self.final_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps, **factory_kwargs)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False, **factory_kwargs)
        if cfg.tie_input_output_embeddings:
            # Share one weight matrix between the input embedding and the
            # output projection; the lm_head then adds no parameters of its own.
            self.lm_head.weight = self.token_embedding.weight

    def parameter_count(self) -> int:
        """Total number of parameters, embeddings included."""
        return sum(parameter.numel() for parameter in self.parameters())

    def body_parameter_count(self) -> int:
        """Parameter count excluding the embedding and output projection.

        With tied embeddings the shared matrix is counted once by
        ``parameter_count`` and subtracted once here, so it never
        double-counts.
        """
        embedding_params = self.token_embedding.weight.numel()
        output_params = 0 if self.cfg.tie_input_output_embeddings else self.lm_head.weight.numel()
        return self.parameter_count() - embedding_params - output_params
    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list[torch.Tensor] | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]]:
        hidden = self.embedding_dropout(self.token_embedding(input_ids))
        # Record per-layer activations and per-block diagnostics so callers
        # can inspect intermediate states without re-running the model.
        layer_outputs: list[torch.Tensor] = [hidden]
        block_outputs: list[dict[str, torch.Tensor | dict[str, torch.Tensor]]] = []
        attention_outputs: list[dict[str, torch.Tensor]] = []
        ffn_outputs: list[dict[str, torch.Tensor]] = []
        for block in self.blocks:
            block_out = block(hidden)
            hidden = block_out["hidden"]
            layer_outputs.append(hidden)
            block_outputs.append(block_out)
            attention_outputs.append(block_out["attention"])
            ffn_outputs.append(block_out["ffn"])
        hidden = self.final_norm(hidden)
        logits = self.lm_head(hidden)
        output: dict[str, torch.Tensor | list[torch.Tensor] | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]] = {
            "hidden": hidden,
            "logits": logits,
            "layer_outputs": layer_outputs,
            "block_outputs": block_outputs,
            "attention_outputs": attention_outputs,
            "ffn_outputs": ffn_outputs,
        }
        if targets is not None:
            # Loss terms are attached only when targets are supplied, so the
            # same forward pass serves both training and inference.
            loss_dict = testformer_lm_loss(logits, targets)
            component_losses = loss_dict["component_losses"]
            output["loss"] = loss_dict["total_loss"]
            output["ce_loss"] = component_losses["ce_loss"]
            output["component_losses"] = component_losses
        return output
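

# A minimal smoke-test sketch of a forward pass on random token ids. The
# TestFormerConfig fields used here (vocab_size, d_model, n_layers) match the
# attributes read above, but the constructor call itself is an assumption of
# this sketch: the real config class may require additional fields (e.g.
# emb_dropout, layer_norm_eps, tie_input_output_embeddings) or use a
# different construction API.
if __name__ == "__main__":
    cfg = TestFormerConfig(  # assumed constructor; adjust to the real config
        vocab_size=256,
        d_model=64,
        n_layers=2,
    )
    model = TestFormerLM(cfg)
    # Batch of 2 sequences, 16 tokens each; targets trigger loss computation.
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    targets = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(input_ids, targets)
    print(f"params={model.parameter_count():,} body={model.body_parameter_count():,}")
    print("loss:", out["loss"].item(), "logits shape:", tuple(out["logits"].shape))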