| | """
|
| | LM-eval harness wrapper for Circuit/Mirrored transformers.
|
| |
|
| | Usage:
|
| | # Single model
|
| | python -m circuits.bench --checkpoint circuits/checkpoints/mirrored/best.pt --gpu 0
|
| |
|
| | # Compare all architectures
|
| | python -m circuits.bench --compare --gpu 0
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import List
|
| | from tqdm import tqdm
|
| | from lm_eval.api.model import LM
|
| | from lm_eval.api.instance import Instance
|
| |
|
| | from .config import CircuitConfig
|
| | from .model import CircuitTransformer
|
| | from .mirrored import MirroredConfig, MirroredTransformer
|
| | from .graft_g2lu import load_g2lu_model
|
| | from .layers import build_word_start_table, compute_word_positions
|
| | from .data import get_tokenizer
|
| |
|
| | def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
|
| | """Migrate checkpoint state_dict to match current model architecture.
|
| |
|
| | Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
|
| | """
|
| | if any(k.startswith("_orig_mod.") for k in state_dict):
|
| | state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
|
| |
|
| | model_keys = set(model.state_dict().keys())
|
| | ckpt_keys = set(state_dict.keys())
|
| |
|
| | missing = model_keys - ckpt_keys
|
| | unexpected = ckpt_keys - model_keys
|
| |
|
| | print(unexpected)
|
| |
|
| | if not missing and not unexpected:
|
| | return state_dict
|
| |
|
| | migrated = dict(state_dict)
|
| | migrations = []
|
| |
|
| |
|
| | for key in list(unexpected):
|
| | if ".ffn.gate_expand.weight" in key:
|
| | new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
|
| | if new_key in missing:
|
| | migrated[new_key] = migrated.pop(key)
|
| | missing.discard(new_key)
|
| | unexpected.discard(key)
|
| | migrations.append(f" {key} → {new_key}")
|
| | if ".ffn.gate_compress.weight" in key:
|
| | new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
|
| | if new_key in missing:
|
| | migrated[new_key] = migrated.pop(key)
|
| | missing.discard(new_key)
|
| | unexpected.discard(key)
|
| | migrations.append(f" {key} → {new_key}")
|
| |
|
| | if migrations:
|
| | print(f"State dict migration ({len(migrations)} keys renamed):")
|
| | for m in migrations:
|
| | print(m)
|
| |
|
| | still_missing = model_keys - set(migrated.keys())
|
| | if still_missing:
|
| | print(f" New parameters (freshly initialized): {len(still_missing)}")
|
| | for k in sorted(still_missing):
|
| | print(f" {k}")
|
| |
|
| | return migrated
|
| |
|
def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load any circuit-family model from a checkpoint, auto-detecting its type.

    Supports three checkpoint flavors, keyed by the ``model_type`` field
    (defaults to ``"standard"`` for legacy checkpoints):
      * ``"graft_g2lu"`` – delegates to :func:`load_g2lu_model`.
      * ``"mirrored"``   – :class:`MirroredTransformer`.
      * anything else    – :class:`CircuitTransformer`.

    Args:
        checkpoint_path: Path to a ``torch.save``'d checkpoint dict.
        device: Device string the model is moved to.

    Returns:
        ``(model, config, arch_name, model_type)`` where the model is in eval
        mode on ``device`` and ``arch_name`` is a human-readable label.
    """
    # weights_only=False: the checkpoint stores plain-Python config dicts and
    # metadata alongside tensors. Only load trusted checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    model_type = checkpoint.get("model_type", "standard")
    if model_type == "graft_g2lu":
        model = load_g2lu_model(checkpoint_path, device=device)
        model.eval()
        n_layers = len(model.g2lu_mlps)
        arch_name = f"G²LU Graft ({checkpoint['pretrained_name']}, {n_layers}L)"
        config = model.model.config
        return model, config, arch_name, model_type
    elif model_type == "mirrored":
        # Drop the legacy 'dual_gate_middle' flag unconditionally. The old
        # code popped it only when truthy, so a stored False survived and
        # could still be rejected by MirroredConfig.from_dict.
        checkpoint["config"].pop("dual_gate_middle", None)
        config = MirroredConfig.from_dict(checkpoint["config"])
        model = MirroredTransformer(config)
        arch_name = f"Mirrored ({model.total_virtual_layers}L)"
    else:
        config = CircuitConfig.from_dict(checkpoint["config"])
        model = CircuitTransformer(config)
        arch_name = f"Standard ({config.num_layers}L)"

    # Migrate legacy parameter names before loading the weights.
    state_dict = checkpoint["model"]
    state_dict = _migrate_state_dict(state_dict, model)
    model.load_state_dict(state_dict)

    model = model.to(device).eval()
    return model, config, arch_name, model_type
|
| |
|
| |
|
class CircuitLM(LM):
    """LM-eval harness adapter for the Circuit transformer family.

    Wraps a checkpoint loaded by :func:`load_model` and implements the three
    lm-eval model endpoints: ``loglikelihood``, ``loglikelihood_rolling`` and
    ``generate_until``. Evaluation is done one request at a time.
    """

    def __init__(
        self,
        checkpoint: str,
        device: str = "cuda",
        batch_size: int = 1,
        compile: bool = False,
    ):
        """Load the model and tokenizer for evaluation.

        Args:
            checkpoint: Path to a circuit-family checkpoint.
            device: Device the model runs on.
            batch_size: Batch size reported to the harness (scoring itself
                is per-request).
            compile: If True, wrap the model with ``torch.compile``.
        """
        super().__init__()

        self.model, self.config, self.arch_name, self.model_type = load_model(
            checkpoint, device
        )

        # Keep an uncompiled handle: graft models call .generate() on it,
        # which a torch.compile wrapper does not reliably expose.
        self._raw_model = self.model
        if compile:
            self.model = torch.compile(self.model)
            print("  torch.compile: enabled")

        # Re-read the checkpoint only for the tokenizer name; load_model does
        # not surface it. NOTE(review): a second full torch.load is wasteful
        # for large checkpoints — consider returning it from load_model.
        _ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
        _tok_name = _ckpt.get("tokenizer_name", "gpt2")
        del _ckpt
        self.tokenizer = get_tokenizer(_tok_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._device = device
        self._batch_size = batch_size

        # Word-position RoPE needs a per-token-id "starts a word" table.
        # config may be a dataclass-like object or a plain dict, so probe both.
        self._word_start_table = None
        word_rope_dims = getattr(self.config, "word_rope_dims", 0)
        if word_rope_dims == 0 and isinstance(self.config, dict):
            word_rope_dims = self.config.get("word_rope_dims", 0)
        if word_rope_dims > 0:
            self._word_start_table = build_word_start_table(
                self.tokenizer, len(self.tokenizer)
            ).to(device)
            print(f"  Word-position RoPE: {word_rope_dims} dims")

        n_params = sum(p.numel() for p in self.model.parameters())
        print(f"  Architecture: {self.arch_name}")
        print(f"  Parameters: {n_params / 1e6:.1f}M")

    @property
    def eot_token_id(self):
        """Token id used as the end-of-text / generation stop token."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length the model accepts (falls back to 512)."""
        return getattr(self.config, "max_seq_len", None) or getattr(
            self.config, "max_position_embeddings", 512
        )

    @property
    def max_gen_toks(self):
        """Default cap on generated tokens per request."""
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str) -> List[int]:
        """Encode text to token ids without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, input_ids: torch.Tensor):
        """Run one forward pass and return logits (no grad, no KV cache).

        Uses bfloat16 autocast on CUDA; autocast is disabled on CPU.
        """
        with torch.inference_mode(), torch.autocast(
            'cuda', dtype=torch.bfloat16, enabled=self._device != "cpu"
        ):
            word_positions = None
            if self._word_start_table is not None:
                word_positions = compute_word_positions(
                    input_ids, self._word_start_table
                )
            output = self.model(
                input_ids, use_cache=False, word_positions=word_positions
            )
            return output["logits"]

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score (context, continuation) token-id pairs.

        Args:
            requests: Iterable of ``(context_enc, continuation_enc)`` lists.

        Returns:
            List of ``(total_log_prob, is_greedy)`` tuples; ``is_greedy`` is
            True when every continuation token is the argmax prediction.
        """
        results = []
        for context_enc, continuation_enc in requests:
            # Left-truncate the context when the pair exceeds the window,
            # but always keep at least one context token — the old code
            # could truncate to an empty context, making the logit slice
            # below empty and silently scoring nothing.
            full_enc = context_enc + continuation_enc
            if len(full_enc) > self.max_length:
                excess = len(full_enc) - self.max_length
                context_enc = context_enc[excess:] or context_enc[-1:]
                full_enc = context_enc + continuation_enc

            input_ids = torch.tensor(
                [full_enc], dtype=torch.long, device=self._device
            )

            logits = self._model_call(input_ids)

            # Logits at position i predict token i+1, so the continuation is
            # scored by logits[ctx_len-1 : -1].
            ctx_len = len(context_enc)
            cont_logits = logits[:, ctx_len - 1 : -1, :]
            cont_tokens = input_ids[:, ctx_len:]

            log_probs = F.log_softmax(cont_logits, dim=-1)
            token_log_probs = log_probs.gather(
                2, cont_tokens.unsqueeze(-1)
            ).squeeze(-1)

            total_log_prob = token_log_probs.sum().item()
            is_greedy = (cont_logits.argmax(dim=-1) == cont_tokens).all().item()

            results.append((total_log_prob, is_greedy))

        return results

    def loglikelihood(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[tuple]:
        """Return ``(log_prob, is_greedy)`` for each (context, continuation)."""
        results = []
        for request in tqdm(
            requests, desc="loglikelihood", disable=disable_tqdm
        ):
            context, continuation = request.args

            # Tokenize context and context+continuation jointly so BPE merges
            # across the boundary are handled, then split at context length.
            context_enc = self.tok_encode(context)
            full_enc = self.tok_encode(context + continuation)
            continuation_enc = full_enc[len(context_enc):]
            if not continuation_enc:
                # The joint encoding merged the continuation into the context
                # tokens entirely; fall back to encoding it on its own.
                continuation_enc = self.tok_encode(continuation)
            result = self._loglikelihood_tokens([(context_enc, continuation_enc)])
            results.append(result[0])
        return results

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Total log-likelihood of full texts, chunked to the model window.

        Chunks are scored independently (no overlap), so the first token of
        every chunk after the first is conditioned on nothing — the standard
        cheap approximation for rolling perplexity.
        """
        results = []
        for request in tqdm(
            requests, desc="loglikelihood_rolling", disable=disable_tqdm
        ):
            text = request.args[0]
            encoding = self.tok_encode(text)

            total_log_prob = 0.0
            max_len = self.max_length

            for i in range(0, len(encoding), max_len):
                chunk = encoding[i : i + max_len]
                input_ids = torch.tensor(
                    [chunk], dtype=torch.long, device=self._device
                )

                logits = self._model_call(input_ids)
                # Shift: position j predicts token j+1 within the chunk.
                shift_logits = logits[:, :-1, :]
                shift_labels = input_ids[:, 1:]

                log_probs = F.log_softmax(shift_logits, dim=-1)
                token_log_probs = log_probs.gather(
                    2, shift_labels.unsqueeze(-1)
                ).squeeze(-1)

                total_log_prob += token_log_probs.sum().item()

            results.append(total_log_prob)
        return results

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Greedy-decode each request until a stop string or token budget."""
        results = []
        for request in tqdm(
            requests, desc="generate_until", disable=disable_tqdm
        ):
            context = request.args[0]
            gen_kwargs = getattr(request, "kwargs", {}) or {}

            until = gen_kwargs.get("until", [self.tokenizer.eos_token])
            max_gen = gen_kwargs.get("max_gen_toks", self.max_gen_toks)

            context_enc = self.tok_encode(context)

            # Reserve room for generation. Guard against max_gen >= window:
            # the old arithmetic produced a negative slice bound there and
            # truncated from the wrong end.
            max_ctx = self.max_length - max_gen
            if max_ctx > 0 and len(context_enc) > max_ctx:
                context_enc = context_enc[-max_ctx:]
            input_ids = torch.tensor(
                [context_enc], dtype=torch.long, device=self._device
            )

            if self.model_type == "graft_g2lu":
                # Grafted HF-backed models expose .generate(); use the
                # uncompiled handle (see __init__).
                with torch.no_grad():
                    output_ids = self._raw_model.generate(
                        input_ids,
                        max_new_tokens=max_gen,
                        do_sample=False,
                        use_cache=True,
                    )
                generated_text = self.tok_decode(
                    output_ids[0, input_ids.shape[1] :].tolist()
                )
            else:
                # Manual greedy decoding without a KV cache: re-run the full
                # prefix each step and take the argmax at the last position.
                generated_ids = input_ids.clone()
                with torch.no_grad():
                    for _ in range(max_gen):
                        # Slide the window if the prefix outgrew the model.
                        # NOTE(review): after sliding, slicing the output by
                        # len(context_enc) can include context tokens; this
                        # preexisting behavior is kept unchanged.
                        if generated_ids.shape[1] > self.max_length:
                            generated_ids = generated_ids[:, -self.max_length :]

                        logits = self._model_call(generated_ids)
                        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                        generated_ids = torch.cat(
                            [generated_ids, next_token], dim=1
                        )

                        if next_token.item() == self.eot_token_id:
                            break

                        # Stop early once any stop string appears in the
                        # decoded continuation so far.
                        current_text = self.tok_decode(
                            generated_ids[0, len(context_enc) :].tolist()
                        )
                        if any(s in current_text for s in until):
                            break

                generated_text = self.tok_decode(
                    generated_ids[0, len(context_enc) :].tolist()
                )

            # Trim the output at the first stop string, if present.
            for stop in until:
                if stop in generated_text:
                    generated_text = generated_text[: generated_text.index(stop)]

            results.append(generated_text)

        return results
|
| |
|