# model_tools/tokeninspector.py
# (Hugging Face upload metadata: uploaded by Naphula, "Upload 8 files", commit 5f463e1, verified)
# python tokeninspector.py "B:\12B\models--mistralai--Mistral-Nemo-Instruct-2407" "B:\12B\models--aixonlab--Aether-12b.backup" "B:\12B\models--aixonlab--Aether-12b"
import os
import click
import torch
import transformers
from mergekit.io.lazy_tensor_loader import LazyTensorLoader
def get_embed_tensor(model_path):
    """Lazily load the token-embedding matrix from a model directory.

    Scans the model's tensor index for the first key containing
    "embed_tokens.weight" (Llama/Mistral-style naming) or "wte.weight"
    (GPT-2-style naming) and loads only that single tensor, so the full
    checkpoint is never materialized in memory.

    Args:
        model_path: Path to a model directory readable by
            ``LazyTensorLoader.from_disk``.

    Returns:
        The embedding weight tensor, or ``None`` if no embedding key is
        present in the index or loading fails for any reason.
    """
    try:
        loader = LazyTensorLoader.from_disk(model_path)
        for key in loader.index.tensor_paths:
            if "embed_tokens.weight" in key or "wte.weight" in key:
                return loader.get_tensor(key)
    except Exception as e:
        # Best-effort audit tool: report the failure and fall through to
        # None instead of aborting the whole audit run.
        print(f" [!] Error loading tensors from {model_path}: {e}")
    # Explicit: reached when no embedding key matched, or on error.
    return None
def _audit_vocab_sizes(base_model, donor_model, output_model):
    """[1] Load the three tokenizers and check output vocab size matches donor."""
    print("\n[1] Loading Tokenizers...")
    tok_base = transformers.AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    tok_donor = transformers.AutoTokenizer.from_pretrained(donor_model, trust_remote_code=True)
    tok_out = transformers.AutoTokenizer.from_pretrained(output_model, trust_remote_code=True)
    print(f" Base: {len(tok_base)} tokens")
    print(f" Donor: {len(tok_donor)} tokens")
    print(f" Output: {len(tok_out)} tokens")
    if len(tok_out) != len(tok_donor):
        print(" ❌ FAIL: Output vocab size does not match Donor vocab size!")
    else:
        print(" ✅ PASS: Output vocab size matches Donor.")
    return tok_base, tok_donor, tok_out


def _audit_embedding_shapes(base_model, donor_model, output_model, tok_donor):
    """[2] Lazily load the three embedding matrices and check the output matrix
    has at least one row per donor-vocab token."""
    print("\n[2] Loading Embedding Tensors (Lazy Load)...")
    emb_base = get_embed_tensor(base_model)
    emb_donor = get_embed_tensor(donor_model)
    emb_out = get_embed_tensor(output_model)
    print(f" Base Matrix: {emb_base.shape if emb_base is not None else 'Not found'}")
    print(f" Donor Matrix: {emb_donor.shape if emb_donor is not None else 'Not found'}")
    print(f" Output Matrix: {emb_out.shape if emb_out is not None else 'Not found'}")
    if emb_out is not None and emb_donor is not None:
        if emb_out.shape[0] >= len(tok_donor):
            print(" ✅ PASS: Output embedding matrix size is sufficient for Donor vocab.")
        else:
            print(" ❌ FAIL: Output embedding matrix is smaller than Donor vocab!")
    return emb_base, emb_donor, emb_out


def _audit_shared_token(shared_tokens, vocab_base, vocab_donor, emb_base, emb_out):
    """[3] For one token present in both vocabs, verify the base embedding
    vector was copied verbatim to the token's (possibly different) output ID."""
    print("\n[3] Testing a Shared Token (Verifying exact transfer)...")
    if not shared_tokens:
        print(" ⚠️ No shared tokens found between vocabularies.")
        return
    # Prefer a common word likely to exist in both vocabs (covers SentencePiece
    # "Ġ"-prefixed and plain variants); otherwise fall back to an arbitrary
    # shared token from the middle of the set.
    candidates = [" the", " hello", "The", "Hello", "Ġthe", "Ġhello", "the", "hello"]
    test_shared = next((c for c in candidates if c in shared_tokens), None)
    if not test_shared:
        test_shared = list(shared_tokens)[len(shared_tokens) // 2]
    id_base = vocab_base[test_shared]
    # NOTE(review): output ID is looked up in the DONOR vocab — the tool
    # assumes the output tokenizer uses the donor's token->id mapping.
    id_out = vocab_donor[test_shared]  # output uses donor vocab
    print(f" Token: '{test_shared}'")
    print(f" ID in Base: {id_base} | ID in Output: {id_out}")
    if emb_base is not None and emb_out is not None:
        vec_base = emb_base[id_base].float()
        vec_out = emb_out[id_out].float()
        cos_sim = torch.nn.functional.cosine_similarity(vec_base, vec_out, dim=0).item()
        print(f" Cosine similarity between Base and Output vectors: {cos_sim:.6f}")
        if cos_sim > 0.999:
            print(" ✅ PASS: Embeddings match perfectly. The vector was successfully moved to the new ID.")
        else:
            print(" ❌ FAIL: Embeddings for shared token do not match!")


def _audit_new_token(donor_only_tokens, vocab_donor, emb_out):
    """[4] For one donor-only token, verify its output embedding is non-zero,
    i.e. OMP produced an approximation rather than leaving a blank row."""
    print("\n[4] Testing a New Token (Verifying OMP approximation)...")
    if not donor_only_tokens:
        print(" ⚠️ No donor-only tokens found. Vocabularies are identical.")
        return
    # Prefer a special/control token (e.g. chat-template markers) as the probe;
    # otherwise take an arbitrary donor-only token.
    test_new = list(donor_only_tokens)[0]
    for t in donor_only_tokens:
        if "<" in t or "[" in t or "im_start" in t:
            test_new = t
            break
    id_out = vocab_donor[test_new]
    print(f" Token: '{test_new}' (Only exists in Donor)")
    print(f" ID in Output: {id_out}")
    if emb_out is not None:
        vec_out = emb_out[id_out].float()
        norm = vec_out.norm().item()
        print(f" Vector L2 Norm: {norm:.4f}")
        if norm > 0.01:
            print(" ✅ PASS: Vector is non-zero. OMP successfully approximated a new embedding.")
        else:
            print(" ⚠️ WARN: Vector is zero or very close to zero. It may have been treated as a junk token.")


def _audit_encoding(tok_donor, tok_out):
    """[5] Verify the output tokenizer encodes text exactly like the donor's."""
    print("\n[5] Testing Tokenizer Encoding Behavior...")
    test_text = "Hello world! This is a test of the new tokenizer. <|im_start|>system\n12345<|im_end|>"
    enc_donor = tok_donor.encode(test_text)
    enc_out = tok_out.encode(test_text)
    if enc_donor == enc_out:
        print(" ✅ PASS: Output model encodes text exactly identically to the Donor model.")
    else:
        print(" ❌ FAIL: Output model encoding differs from Donor model!")
        print(f" Donor: {enc_donor[:10]}...")
        print(f" Output: {enc_out[:10]}...")


@click.command()
@click.argument("base_model", type=click.Path(exists=True))
@click.argument("donor_model", type=click.Path(exists=True))
@click.argument("output_model", type=click.Path(exists=True))
def main(base_model, donor_model, output_model):
    """Audit a token-surgery merge.

    Checks that OUTPUT_MODEL carries the donor's tokenizer/vocabulary while
    preserving (shared tokens) or approximating (new tokens) the base model's
    embedding vectors. All three arguments are local model directories.
    """
    print("=" * 60)
    print("🔍 TOKEN SURGEON AUDIT TOOL")
    print("=" * 60)

    tok_base, tok_donor, tok_out = _audit_vocab_sizes(base_model, donor_model, output_model)
    emb_base, emb_donor, emb_out = _audit_embedding_shapes(
        base_model, donor_model, output_model, tok_donor
    )

    vocab_base = tok_base.get_vocab()
    vocab_donor = tok_donor.get_vocab()
    shared_tokens = set(vocab_base) & set(vocab_donor)
    donor_only_tokens = set(vocab_donor) - set(vocab_base)

    _audit_shared_token(shared_tokens, vocab_base, vocab_donor, emb_base, emb_out)
    _audit_new_token(donor_only_tokens, vocab_donor, emb_out)
    _audit_encoding(tok_donor, tok_out)

    print("\n" + "=" * 60)
    print("Audit Complete.")
    print("=" * 60)
# Script entry point: click parses the three path arguments from sys.argv.
if __name__ == '__main__':
    main()