# geolip-captionbert-8192 / colab_deep_analysis.py
# Author: AbstractPhil — commit 7f9ee11 ("Create colab_deep_analysis.py", verified)
# ============================================================================
# INTERNAL ANALYZER: CaptionBERT-8192
#
# Sees inside the model, not just the output. Five diagnostic lenses:
# 1. Spectral trajectories — singular-value spectrum evolution per layer
# 2. Effective dimensionality — how deeply each input is understood
# 3. Cross-layer divergence — where computation actually happens
# 4. Token influence — which input tokens drive the output
# 5. Neighborhood structure — local geometry around each input
#
# Usage:
# analyzer = InternalAnalyzer(model, tokenizer)
# report = analyzer.analyze(["girl", "woman", "subtraction", "multiplication"])
# analyzer.print_report(report)
# analyzer.compare(report, "girl", "subtraction")
# ============================================================================
import torch
import torch.nn.functional as F
import numpy as np
from collections import defaultdict
# Target device for the model and every tokenized batch: CUDA when a GPU is
# visible to torch, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
class InternalAnalyzer:
    """Probe the internal representations of a CaptionBERT-style encoder.

    Runs several diagnostics over a batch of texts: per-layer token-level
    singular-value spectra, local neighborhood dimensionality of the final
    embeddings, layer-to-layer drift of the mean-pooled representation, and
    gradient-based token influence.

    NOTE(review): ``token_influence`` re-implements the forward pass through
    ``model.token_emb``, ``model.pos_emb``, ``model.emb_norm``,
    ``model.emb_drop``, ``model.encoder.layers`` and ``model.output_proj``.
    It assumes the model exposes exactly these submodules and that its own
    ``forward`` does nothing else (e.g. no final encoder norm) — confirm
    against the model implementation.
    """

    def __init__(self, model, tokenizer, max_len=512):
        """
        Args:
            model: the encoder; moved to ``DEVICE`` and switched to eval mode.
            tokenizer: HF-style tokenizer returning ``input_ids`` and
                ``attention_mask`` and supporting ``convert_ids_to_tokens``.
            max_len: every input is truncated / padded to this many tokens.
        """
        self.model = model.to(DEVICE).eval()
        self.tokenizer = tokenizer
        self.max_len = max_len

    # ══════════════════════════════════════════════════════════════
    # CORE: Extract all layer representations
    # ══════════════════════════════════════════════════════════════

    @torch.no_grad()
    def extract_layers(self, texts):
        """Run one forward pass and collect every layer's representation.

        Args:
            texts: a string or a list of strings.

        Returns:
            dict with keys
              ``texts``           -- the inputs as a list,
              ``layer_pooled``    -- one (B, D) CPU tensor per hidden state,
                                     mask-aware mean over real tokens,
              ``layer_raw``       -- the model's ``hidden_states`` tuple of
                                     (B, max_len, D) tensors, left on DEVICE,
              ``final_embedding`` -- ``outputs.last_hidden_state`` on CPU
                                     (the original notes say (B, 768) —
                                     presumably the projected embedding),
              ``attention_mask``  -- (B, max_len) CPU tensor,
              ``n_tokens``        -- (B,) number of non-padding tokens.
        """
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.tokenizer(
            texts, max_length=self.max_len, padding="max_length",
            truncation=True, return_tensors="pt").to(DEVICE)
        outputs = self.model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True)

        mask = inputs["attention_mask"].unsqueeze(-1).float()
        # The per-row token count is identical for every layer; compute once.
        denom = mask.sum(1).clamp(min=1)
        layer_pooled = [((h * mask).sum(1) / denom).cpu()
                        for h in outputs.hidden_states]

        return {
            "texts": texts,
            "layer_pooled": layer_pooled,
            "layer_raw": outputs.hidden_states,
            "final_embedding": outputs.last_hidden_state.cpu(),
            "attention_mask": inputs["attention_mask"].cpu(),
            "n_tokens": inputs["attention_mask"].sum(-1).cpu(),
        }

    # ══════════════════════════════════════════════════════════════
    # 1. SPECTRAL TRAJECTORIES
    # ══════════════════════════════════════════════════════════════

    @staticmethod
    def _spectral_stats(h, n_real):
        """Spectral summary of one input's (seq_len, D) token matrix."""
        # Placeholder used when the spectrum is undefined; carries the same
        # keys as a real entry so downstream code can index uniformly.
        empty = {"spectrum": [], "eff_dim": 0, "entropy": 0, "top1_ratio": 0}
        if n_real < 2:
            return empty
        # Keeps the first n_real rows, i.e. assumes right padding (real
        # tokens first) — TODO confirm for the tokenizer in use.
        h = h[:n_real].cpu().float()
        h_centered = h - h.mean(0, keepdim=True)
        try:
            S = torch.linalg.svdvals(h_centered)
        except RuntimeError:  # torch.linalg.LinAlgError derives from this
            return empty

        total = S.sum()
        S_norm = S / (total + 1e-12)
        # Participation ratio computed on the singular values themselves,
        # (sum S)^2 / sum S^2 (not on their squares / eigenvalues).
        eff_dim = total ** 2 / (S.pow(2).sum() + 1e-12)
        # Shannon entropy of the normalized singular-value distribution.
        S_pos = S_norm[S_norm > 1e-12]
        entropy = -(S_pos * S_pos.log()).sum()
        return {
            "spectrum": S[:20].tolist(),      # top 20 singular values
            "eff_dim": eff_dim.item(),
            "entropy": entropy.item(),
            "top1_ratio": S_norm[0].item(),
        }

    def spectral_trajectory(self, data):
        """Singular-value statistics of each input's token matrix, per layer.

        For every input and every hidden state, the real (non-padding) token
        vectors are mean-centered and their singular values computed.

        Args:
            data: dict from ``extract_layers``.

        Returns:
            list (one per input) of ``{"text", "trajectory"}``; each
            trajectory entry has ``spectrum``, ``eff_dim``, ``entropy`` and
            ``top1_ratio``. Inputs with fewer than 2 real tokens (or an SVD
            failure) get zero placeholders.
        """
        n_layers = len(data["layer_pooled"])
        B = data["layer_pooled"][0].shape[0]
        results = []
        for b in range(B):
            n_real = int(data["attention_mask"][b].sum().item())
            trajectory = [
                self._spectral_stats(data["layer_raw"][layer_idx][b], n_real)
                for layer_idx in range(n_layers)
            ]
            results.append({"text": data["texts"][b], "trajectory": trajectory})
        return results

    # ══════════════════════════════════════════════════════════════
    # 2. EFFECTIVE DIMENSIONALITY (output space)
    # ══════════════════════════════════════════════════════════════

    @staticmethod
    def _local_placeholder(text):
        """Zero-valued result for inputs whose neighborhood is undefined."""
        return {"text": text, "eff_dim": 0, "decay_rate": 0, "local_spread": 0}

    def effective_dimensionality(self, data, k_neighbors=50):
        """Local PCA dimensionality around each final embedding.

        Neighbors are the other inputs *in the same batch* ranked by dot
        product (cosine if the final embeddings are unit-normalized —
        unverified here). With small batches every input's neighborhood is
        therefore (almost) the whole rest of the batch. The author's reading:
        high = rich understanding, low = surface-level placement.

        Args:
            data: dict from ``extract_layers``.
            k_neighbors: neighborhood size; capped at ``B - 1`` so the input
                itself is never included.

        Returns:
            list (one per input) of dicts with ``text``, ``eff_dim``
            (participation ratio of the neighbors' singular values),
            ``decay_rate`` (share of singular-value mass in the top 5) and
            ``local_spread`` (mean distance of neighbors to their centroid).
            Zero placeholders when there are no neighbors or SVD fails.
        """
        embeddings = data["final_embedding"].float()
        texts = data["texts"]
        B = embeddings.shape[0]
        # A single input (or k_neighbors <= 0) has no neighborhood; the old
        # floor of 2 made topk() request more items than exist and crash.
        k = min(k_neighbors, B - 1)
        if k < 1:
            return [self._local_placeholder(texts[b]) for b in range(B)]

        sim = embeddings @ embeddings.T
        results = []
        for b in range(B):
            sims = sim[b].clone()
            sims[b] = float("-inf")  # never select the input itself
            _, topk_idx = sims.topk(k)
            neighbors = embeddings[topk_idx]
            centered = neighbors - neighbors.mean(0, keepdim=True)
            try:
                S = torch.linalg.svdvals(centered)
            except RuntimeError:
                results.append(self._local_placeholder(texts[b]))
                continue
            total = S.sum()
            eff_dim = total ** 2 / (S.pow(2).sum() + 1e-12)
            # Guarded ratio: coincident neighbors give S == 0, which used to
            # produce 0/0 = NaN.
            decay_rate = (S[:5].sum() / (total + 1e-12)).item()
            results.append({
                "text": texts[b],
                "eff_dim": eff_dim.item(),
                "decay_rate": decay_rate,  # high = concentrated, low = spread
                "local_spread": centered.norm(dim=-1).mean().item(),
            })
        return results

    # ══════════════════════════════════════════════════════════════
    # 3. CROSS-LAYER DIVERGENCE
    # ══════════════════════════════════════════════════════════════

    def cross_layer_divergence(self, data):
        """How much the mean-pooled representation moves between layers.

        Args:
            data: dict from ``extract_layers`` (only ``texts`` and
                ``layer_pooled`` are read).

        Returns:
            list (one per input) of dicts with ``profile`` (per transition:
            ``layer`` label, ``cosine``, ``l2_shift``, ``angle_rad``),
            ``total_path`` (sum of L2 shifts), ``max_shift_layer`` (index of
            the transition with the largest L2 shift, -1 if there is only one
            layer) and ``input_output_cos`` (first vs last hidden state).
        """
        layers = data["layer_pooled"]
        B = layers[0].shape[0]
        results = []
        for b in range(B):
            path = [layer[b].float() for layer in layers]
            profile = []
            for i, (h_curr, h_next) in enumerate(zip(path, path[1:])):
                cos = F.cosine_similarity(h_curr.unsqueeze(0),
                                          h_next.unsqueeze(0)).item()
                dot = (F.normalize(h_curr, dim=0)
                       * F.normalize(h_next, dim=0)).sum()
                profile.append({
                    "layer": f"{i}→{i+1}",
                    "cosine": cos,
                    "l2_shift": (h_next - h_curr).norm().item(),
                    # Clamp guards acos against rounding just outside [-1, 1].
                    "angle_rad": torch.acos(torch.clamp(dot, -1, 1)).item(),
                })
            results.append({
                "text": data["texts"][b],
                "profile": profile,
                "total_path": sum(p["l2_shift"] for p in profile),
                "max_shift_layer": max(
                    range(len(profile)),
                    key=lambda j: profile[j]["l2_shift"],
                    default=-1),
                "input_output_cos": F.cosine_similarity(
                    path[0].unsqueeze(0), path[-1].unsqueeze(0)).item(),
            })
        return results

    # ══════════════════════════════════════════════════════════════
    # 4. TOKEN INFLUENCE (gradient-based)
    # ══════════════════════════════════════════════════════════════

    def token_influence(self, texts):
        """Gradient saliency of each input token on the projected output.

        The embedding-layer output (token + position embeddings after
        ``emb_norm``/``emb_drop``) is treated as the input variable; the
        encoder layers, masked mean pool and ``output_proj`` are re-run, and
        the gradient of ``sum(normalize(output_proj(pooled)))`` w.r.t. that
        input is taken. Because the output is L2-normalized its norm is
        constant, so the backpropagated scalar is the sum of its components
        (a projection on the all-ones direction), not its norm.

        Args:
            texts: a string or a list of strings (processed one at a time).

        Returns:
            list (one per text) of dicts with ``tokens``, ``influence``
            (per-token gradient L2 norm, normalized to sum to 1 over real
            tokens), ``top_tokens`` (top 10 ``(token, score)`` pairs) and
            ``concentration`` (max / mean influence).
        """
        if isinstance(texts, str):
            texts = [texts]
        model = self.model
        results = []
        for text in texts:
            inputs = self.tokenizer(
                [text], max_length=self.max_len, padding="max_length",
                truncation=True, return_tensors="pt").to(DEVICE)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            n_real = int(attention_mask.sum().item())

            # Autograd must be active even if the caller is inside no_grad().
            with torch.enable_grad():
                positions = torch.arange(input_ids.shape[1],
                                         device=input_ids.device).unsqueeze(0)
                emb = model.token_emb(input_ids) + model.pos_emb(positions)
                emb = model.emb_drop(model.emb_norm(emb))
                # A fresh leaf: works when parameters are frozen, and with
                # autograd.grad below no parameter .grad is touched.
                emb = emb.detach().requires_grad_(True)

                kpm = ~attention_mask.bool()
                x = emb
                for layer in model.encoder.layers:
                    x = layer(x, src_key_padding_mask=kpm)

                mask = attention_mask.unsqueeze(-1).float()
                pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
                output = F.normalize(model.output_proj(pooled), dim=-1)
                (grad,) = torch.autograd.grad(output.sum(), emb)

            influence = grad[0].cpu().norm(dim=-1)[:n_real]  # real tokens only
            influence = influence / (influence.sum() + 1e-12)
            scores = influence.tolist()

            token_ids = input_ids[0][:n_real].cpu().tolist()
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)

            results.append({
                "text": text,
                "tokens": tokens,
                "influence": scores,
                "top_tokens": sorted(zip(tokens, scores),
                                     key=lambda pair: -pair[1])[:10],
                "concentration": (influence.max()
                                  / (influence.mean() + 1e-12)).item(),
            })
        return results

    # ══════════════════════════════════════════════════════════════
    # 5. FULL ANALYSIS
    # ══════════════════════════════════════════════════════════════

    def analyze(self, texts):
        """Run every diagnostic on ``texts``.

        Returns:
            dict keyed by text (duplicate texts collapse to the last one)
            with ``embedding``, ``n_tokens``, ``spectral``, ``eff_dim``,
            ``divergence`` and ``influence`` entries. Empty input -> ``{}``.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            return {}
        print(f" Analyzing {len(texts)} inputs...")
        data = self.extract_layers(texts)
        spectral = self.spectral_trajectory(data)
        eff_dim = self.effective_dimensionality(data)
        divergence = self.cross_layer_divergence(data)
        influence = self.token_influence(texts)

        report = {}
        for i, text in enumerate(texts):
            report[text] = {
                "embedding": data["final_embedding"][i],
                "n_tokens": data["n_tokens"][i].item(),
                "spectral": spectral[i],
                "eff_dim": eff_dim[i] if i < len(eff_dim) else {},
                "divergence": divergence[i],
                "influence": influence[i],
            }
        return report

    # ══════════════════════════════════════════════════════════════
    # PRINTING
    # ══════════════════════════════════════════════════════════════

    def print_report(self, report):
        """Print the summary table and per-layer tables for ``report``."""
        print(f"\n{'='*70}")
        print("INTERNAL ANALYSIS REPORT")
        print(f"{'='*70}")

        # Summary table
        print(f"\n {'Text':<25} {'Tokens':>6} {'EffDim':>7} {'Path':>7} "
              f"{'MaxShift':>9} {'InOutCos':>8} {'Concentrate':>11}")
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            ed = r["eff_dim"].get("eff_dim", 0)
            tp = r["divergence"]["total_path"]
            ms = r["divergence"]["max_shift_layer"]
            ioc = r["divergence"]["input_output_cos"]
            conc = r["influence"]["concentration"]
            print(f" {label:<25} {r['n_tokens']:>6} {ed:>7.1f} {tp:>7.2f} "
                  f" layer {ms:>2} {ioc:>7.3f} {conc:>10.1f}")

        # Nothing to tabulate per layer (previously raised StopIteration).
        if not report:
            return

        # Spectral evolution
        print(f"\n SPECTRAL TRAJECTORY (effective dim per layer):")
        print(f" {'Text':<25}", end="")
        n_layers = len(next(iter(report.values()))["spectral"]["trajectory"])
        for i in range(n_layers):
            print(f" L{i:>2}", end="")
        print()
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["spectral"]["trajectory"]:
                ed = step.get("eff_dim", 0)
                print(f" {ed:>4.0f}", end="")
            print()

        # Spectral entropy per layer
        print(f"\n SPECTRAL ENTROPY (information content per layer):")
        print(f" {'Text':<25}", end="")
        for i in range(n_layers):
            print(f" L{i:>2}", end="")
        print()
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["spectral"]["trajectory"]:
                ent = step.get("entropy", 0)
                print(f" {ent:>4.1f}", end="")
            print()

        # Cross-layer divergence profiles
        print(f"\n COMPUTATION PROFILE (L2 shift between layers):")
        print(f" {'Text':<25}", end="")
        for i in range(n_layers - 1):
            print(f" {i}→{i+1:>2}", end="")
        print()
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["divergence"]["profile"]:
                print(f" {step['l2_shift']:>4.1f}", end="")
            print()

        # Token influence for each input
        print(f"\n TOKEN INFLUENCE (top contributing tokens):")
        for text, r in report.items():
            top = r["influence"]["top_tokens"][:5]
            tok_str = " ".join(f"{t}={v:.3f}" for t, v in top)
            print(f" {text[:40]:<42} {tok_str}")

    def compare(self, report, text_a, text_b):
        """Print a side-by-side comparison of two analyzed inputs.

        Args:
            report: dict returned by ``analyze``.
            text_a, text_b: keys of ``report`` (KeyError if absent).
        """
        a = report[text_a]
        b = report[text_b]
        # Cast: the stored embedding keeps the model's dtype, which may be a
        # reduced-precision type.
        cos = F.cosine_similarity(
            a["embedding"].float().unsqueeze(0),
            b["embedding"].float().unsqueeze(0)).item()

        print(f"\n{'='*70}")
        print(f"COMPARISON: '{text_a}' vs '{text_b}'")
        print(f"{'='*70}")
        print(f" Output cosine: {cos:.4f}")
        print(f" Tokens: {a['n_tokens']} vs {b['n_tokens']}")

        # Effective dim comparison
        ed_a = a["eff_dim"].get("eff_dim", 0)
        ed_b = b["eff_dim"].get("eff_dim", 0)
        print(f" Effective dim: {ed_a:.1f} vs {ed_b:.1f} (Δ={abs(ed_a-ed_b):.1f})")

        # Path comparison
        pa = a["divergence"]["total_path"]
        pb = b["divergence"]["total_path"]
        print(f" Total path: {pa:.2f} vs {pb:.2f} (Δ={abs(pa-pb):.2f})")

        # Layer-by-layer spectral comparison
        print(f"\n Effective dim trajectory:")
        print(f" {'Layer':<8} {'A':>8} {'B':>8} {'Δ':>8}")
        steps = zip(a["spectral"]["trajectory"], b["spectral"]["trajectory"])
        for i, (step_a, step_b) in enumerate(steps):
            ea = step_a.get("eff_dim", 0)
            eb = step_b.get("eff_dim", 0)
            print(f" L{i:<6} {ea:>8.1f} {eb:>8.1f} {abs(ea-eb):>8.1f}")

        # Divergence profile comparison
        print(f"\n Computation profile (L2 shift):")
        print(f" {'Transition':<10} {'A':>8} {'B':>8} {'Δ':>8}")
        for step_a, step_b in zip(a["divergence"]["profile"],
                                  b["divergence"]["profile"]):
            sa = step_a["l2_shift"]
            sb = step_b["l2_shift"]
            label = step_a["layer"]
            print(f" {label:<10} {sa:>8.2f} {sb:>8.2f} {abs(sa-sb):>8.2f}")

        # Token influence comparison
        print(f"\n Top tokens:")
        print(f" A: {' '.join(f'{t}={v:.3f}' for t,v in a['influence']['top_tokens'][:5])}")
        print(f" B: {' '.join(f'{t}={v:.3f}' for t,v in b['influence']['top_tokens'][:5])}")
# ══════════════════════════════════════════════════════════════════
# RUN
# ══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    REPO_ID = "AbstractPhil/geolip-captionbert-8192"

    print("Loading model...")
    # NOTE: trust_remote_code=True executes model code hosted in the repo;
    # only use it with repositories you trust.
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
    analyzer = InternalAnalyzer(model, tokenizer)

    # Probe set: caption-domain words, abstract words, then full phrases.
    caption_words = ["girl", "woman", "dog", "sunset", "painting"]
    abstract_words = ["subtraction", "multiplication", "prophetic",
                      "differential", "adjacency"]
    phrases = [
        "a girl sitting near a window",
        "a dog playing on the beach",
        "the differential equation of motion",
    ]
    test_words = [*caption_words, *abstract_words, *phrases]

    report = analyzer.analyze(test_words)
    analyzer.print_report(report)

    # Pairwise deep-dives: near-synonyms, known vs unknown domain,
    # and caption phrase vs technical phrase.
    comparisons = (
        ("girl", "woman"),
        ("girl", "subtraction"),
        ("a girl sitting near a window", "the differential equation of motion"),
    )
    for text_a, text_b in comparisons:
        analyzer.compare(report, text_a, text_b)

    rule = "=" * 70
    print(f"\n{rule}")
    print("DONE")
    print(rule)