| """Experiment functions for the reFlow interpretability demo, adapted for Gradio.""" |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.ticker as ticker |
| import seaborn as sns |
| from sklearn.decomposition import PCA |
| from sklearn.metrics import silhouette_score |
|
|
| try: |
| from adjustText import adjust_text |
| except ImportError: |
| adjust_text = lambda texts, **kwargs: None |
|
|
| from model_loader import get_model, get_cached_tensors |
|
|
# GPT-2 BPE vocabulary size. All experiments slice embedding tables to
# [:REAL_VOCAB]; rows beyond this index are presumably padding — TODO confirm.
REAL_VOCAB = 50257
|
|
| |
| |
| |
|
|
| def _embed(model, ids): |
| result = model.transformer.wte(ids) |
| return result[0] if isinstance(result, tuple) else result |
|
|
|
|
| def _get_vocab_signals(model): |
| wte = model.transformer.wte |
| if hasattr(wte, '_apply_sparsity'): |
| return wte._apply_sparsity(wte.vocab_to_signals.weight.data) |
| return wte.vocab_to_signals.weight.data |
|
|
|
|
def _forward_through_layers(model, ids):
    """Embed *ids* and push the activations through every transformer block
    (gradient-free); return the final hidden states."""
    with torch.no_grad():
        hidden = _embed(model, ids)
        rope = model.freqs_cis[:ids.size(1)]
        for layer in model.transformer.h:
            hidden = layer(hidden, rope)
    return hidden
|
|
|
|
| def _get_logits_from_hidden(model, x_norm): |
| vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix() |
| return F.linear(x_norm, vocab_matrix) |
|
|
|
|
| def _gini(arr): |
| arr = np.sort(np.abs(arr)) |
| n = len(arr) |
| if n == 0 or np.sum(arr) == 0: |
| return 0.0 |
| index = np.arange(1, n + 1) |
| return (2 * np.sum(index * arr) / (n * np.sum(arr))) - (n + 1) / n |
|
|
|
|
| |
| |
| |
|
|
# Preset word groups for the semantic-galaxy experiment: cluster label -> six
# representative words (each later encoded with a leading space to get the
# word-start BPE token).
DEFAULT_CLUSTERS = {
    "Countries": ["China", "France", "Germany", "Japan", "India", "Russia"],
    "Animals": ["cat", "dog", "fish", "bird", "horse", "bear"],
    "Numbers": ["one", "two", "three", "four", "five", "ten"],
    "Colors": ["red", "blue", "green", "black", "white", "yellow"],
    "Emotions": ["happy", "sad", "angry", "love", "fear", "hate"],
}
|
|
|
|
@torch.inference_mode()
def exp_semantic_galaxy(
    use_countries, use_animals, use_numbers, use_colors, use_emotions, custom_words
):
    """Scatter-plot word "recipes" (vocab-to-signal rows) in 2-D PCA space.

    Each ``use_*`` flag toggles the matching DEFAULT_CLUSTERS group;
    ``custom_words`` is an optional comma-separated string shown as an extra
    "Custom" cluster.  Returns a matplotlib Figure (Gradio renders it).
    """
    model, enc, device = get_model()
    # One signal recipe per vocab id, as numpy for sklearn/matplotlib.
    W_v2s = _get_vocab_signals(model).cpu().numpy()

    clusters = {}
    if use_countries:
        clusters["Countries"] = DEFAULT_CLUSTERS["Countries"]
    if use_animals:
        clusters["Animals"] = DEFAULT_CLUSTERS["Animals"]
    if use_numbers:
        clusters["Numbers"] = DEFAULT_CLUSTERS["Numbers"]
    if use_colors:
        clusters["Colors"] = DEFAULT_CLUSTERS["Colors"]
    if use_emotions:
        clusters["Emotions"] = DEFAULT_CLUSTERS["Emotions"]

    if custom_words and custom_words.strip():
        custom_list = [w.strip() for w in custom_words.split(",") if w.strip()]
        if custom_list:
            clusters["Custom"] = custom_list

    # Nothing selected at all: fall back to showing every default cluster.
    if not clusters:
        clusters = DEFAULT_CLUSTERS

    recipes, labels, words = [], [], []
    for cat, wl in clusters.items():
        for w in wl:
            # Leading space selects the word-start BPE token.
            tids = enc.encode(" " + w)
            if tids and tids[0] < REAL_VOCAB:
                recipes.append(W_v2s[tids[0]])
                labels.append(cat)
                words.append(w)

    # PCA/silhouette need a handful of points; show a placeholder otherwise.
    if len(words) < 3:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, "Need at least 3 valid words", ha='center', va='center', fontsize=14)
        ax.axis('off')
        return fig

    recipes_arr = np.array(recipes)
    coords = PCA(n_components=2).fit_transform(recipes_arr)

    # Cluster-separation score in the ORIGINAL recipe space, not the 2-D view.
    # NOTE(review): silhouette_score raises when n_labels == n_samples;
    # assumed each cluster contributes >= 2 valid words — confirm for custom input.
    label_ids = [list(clusters.keys()).index(l) for l in labels]
    sil = silhouette_score(recipes_arr, label_ids) if len(set(label_ids)) >= 2 else 0.0

    fig = plt.figure(figsize=(12, 9))
    color_map = dict(zip(clusters.keys(), sns.color_palette("Set2", len(clusters))))

    texts = []
    for i, w in enumerate(words):
        plt.scatter(coords[i, 0], coords[i, 1], color=color_map[labels[i]],
                    s=150, alpha=0.7, edgecolors='white', linewidths=0.5)
        texts.append(plt.text(coords[i, 0], coords[i, 1], w, fontsize=11))

    # Only run the real adjust_text; skip the no-op lambda fallback.
    if callable(adjust_text) and getattr(adjust_text, '__name__', '') != '<lambda>':
        adjust_text(texts, arrowprops=dict(arrowstyle="-", color='gray'))

    # Proxy artists so the legend shows one colored marker per cluster.
    handles = [plt.Line2D([0], [0], marker='o', color='w',
                          markerfacecolor=color_map[l], markersize=12, label=l) for l in clusters]
    plt.legend(handles=handles, title="Clusters", fontsize=10)
    plt.title(f"reFlow Semantic Galaxy (PCA)\nSilhouette Score = {sil:.4f}",
              fontsize=14, fontweight='bold')
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    return fig
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_semantic_algebra(positive_words, negative_words):
    """Word-vector arithmetic on signal recipes (e.g. king - man + woman).

    Both arguments are comma-separated strings; positive recipes are summed
    into the target vector and negative recipes subtracted.  Returns a
    plain-text ranking of the most cosine-similar vocabulary tokens, with the
    input tokens themselves excluded.
    """
    model, enc, device = get_model()
    W_v2s = _get_vocab_signals(model)
    W_valid = W_v2s[:REAL_VOCAB]

    pos_list = [w.strip() for w in positive_words.split(",") if w.strip()]
    neg_list = [w.strip() for w in negative_words.split(",") if w.strip()]

    if not pos_list:
        return "Please enter at least one positive word."

    target_vec = torch.zeros(model.config.n_signals, device=device)
    exclude_ids = set()  # input tokens must not appear in their own ranking

    for w in pos_list:
        # Leading space selects the word-start BPE token.
        tids = enc.encode(" " + w)
        if tids and tids[0] < REAL_VOCAB:
            target_vec += W_v2s[tids[0]]
            exclude_ids.add(tids[0])
    for w in neg_list:
        tids = enc.encode(" " + w)
        if tids and tids[0] < REAL_VOCAB:
            target_vec -= W_v2s[tids[0]]
            exclude_ids.add(tids[0])

    # Cosine similarity of the composite vector against every real token.
    sims = F.cosine_similarity(target_vec.unsqueeze(0), W_valid)
    for tid in exclude_ids:
        sims[tid] = -1.0

    # Over-fetch 20 so undecodable/empty tokens can be skipped and 15 remain.
    top_vals, top_ids = torch.topk(sims, 20)

    expr = " + ".join(pos_list)
    if neg_list:
        expr += " - " + " - ".join(neg_list)

    rows = []
    for i in range(len(top_ids)):
        try:
            w = enc.decode([top_ids[i].item()]).strip()
            if len(w) >= 1:
                rows.append(f"#{len(rows)+1:2d} {w:<20s} cos={top_vals[i].item():.4f}")
        except Exception:
            continue
        if len(rows) >= 15:
            break

    header = f"Expression: {expr}\n{'='*50}\nRank Word Similarity\n{'-'*50}\n"
    return header + "\n".join(rows)
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_typo_resilience(sent_normal, sent_typo, sent_diff):
    """Bar-chart cosine similarity of deep "signal" representations.

    Hypothesis being visualized: a typo'd sentence stays close to the clean
    one in signal space, while a semantically different sentence drifts away.
    Returns a matplotlib Figure.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data

    def get_deep_signal(text):
        # Signal activations of the final position after all layers + ln_f.
        ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
        x = _forward_through_layers(model, ids)
        x_norm = model.transformer.ln_f(x[0, -1, :])
        return x_norm @ W_basis.t()

    sig_normal = get_deep_signal(sent_normal)
    sig_typo = get_deep_signal(sent_typo)
    sig_diff = get_deep_signal(sent_diff)

    sim_typo = F.cosine_similarity(sig_normal.unsqueeze(0), sig_typo.unsqueeze(0)).item()
    sim_diff = F.cosine_similarity(sig_normal.unsqueeze(0), sig_diff.unsqueeze(0)).item()

    fig, ax = plt.subplots(figsize=(8, 5))
    categories = ['Self\n(baseline)', 'Normal vs Typo\n(same meaning)', 'Normal vs Different\n(different meaning)']
    values = [1.0, sim_typo, sim_diff]
    colors = ['#2ecc71', '#f39c12', '#e74c3c']  # green / amber / red
    bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', width=0.5)
    for bar, val in zip(bars, values):
        # Print each similarity just above its bar.
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{val:.4f}', ha='center', fontsize=11, fontweight='bold')
    ax.set_ylim(0, 1.15)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title("reFlow Typo Resilience - Deep Signal Similarity", fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    return fig
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_sparsity_profile(word_to_inspect):
    """Profile how sparsely words use the signal dictionary.

    A (word, signal) entry counts as "active" when its absolute weight exceeds
    mean + std of all absolute weights.  Left plot: histogram of active
    signals per word; right plot: per-signal utilization.  Optionally drills
    into one word's recipe.  Returns (fig, stats_text).
    """
    model, enc, device = get_model()
    W_v2s = _get_vocab_signals(model)
    W = W_v2s[:REAL_VOCAB]
    vocab_size, n_signals = W.shape

    # Global activity threshold: one standard deviation above the mean |weight|.
    mean_val = W.abs().mean().item()
    std_val = W.abs().std().item()
    threshold = mean_val + std_val
    active_mask = W.abs() > threshold

    active_per_word = active_mask.sum(dim=1).cpu().numpy()
    active_per_signal = active_mask.sum(dim=0).cpu().numpy()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Integer-centered bins so every distinct count gets its own bar.
    int_bins = np.arange(active_per_word.min(), active_per_word.max() + 2) - 0.5
    axes[0].hist(active_per_word, bins=int_bins, color='teal', alpha=0.7, edgecolor='black')
    axes[0].axvline(x=np.mean(active_per_word), color='red', linestyle='--',
                    label=f'Mean: {np.mean(active_per_word):.1f}')
    axes[0].set_title("Per-Word Sparsity (# Active Signals)")
    axes[0].set_xlabel("Number of Active Signals")
    axes[0].set_ylabel("Frequency")
    axes[0].legend()

    axes[1].bar(range(n_signals), active_per_signal, color='coral', alpha=0.7, width=1.0)
    axes[1].set_title("Signal Utilization (# words activating each signal)")
    axes[1].set_xlabel("Signal Index")
    axes[1].set_ylabel("# Words")
    axes[1].axhline(y=np.mean(active_per_signal), color='red', linestyle='--',
                    label=f'Mean: {np.mean(active_per_signal):.0f}')
    axes[1].legend()

    plt.suptitle("reFlow Sparsity Profile", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # Text companion to the plots.
    stats_text = f"Threshold: {threshold:.4f} (mean + std)\n"
    stats_text += f"Avg active signals per word: {np.mean(active_per_word):.1f} / {n_signals}\n"
    stats_text += f"Global activation rate: {active_mask.float().mean().item():.2%}\n"

    # Optional drill-down into a single word's recipe.
    if word_to_inspect and word_to_inspect.strip():
        w = word_to_inspect.strip()
        tids = enc.encode(" " + w)
        if tids and tids[0] < REAL_VOCAB:
            word_recipe = W[tids[0]]
            word_active = (word_recipe.abs() > threshold).sum().item()
            top_sigs = torch.argsort(word_recipe.abs(), descending=True)[:10]
            stats_text += f"\n--- '{w}' ---\n"
            stats_text += f"Active signals: {word_active}\n"
            stats_text += f"Top 10 signal indices: {top_sigs.tolist()}\n"
            stats_text += f"Top 10 amplitudes: {[f'{word_recipe[s].item():.4f}' for s in top_sigs]}\n"
        else:
            stats_text += f"\n'{w}' not found in vocabulary.\n"

    return fig, stats_text
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_layer_evolution(prompt_text):
    """Track next-token probabilities and entropy layer by layer (logit lens).

    Runs the prompt once, decoding the final position's hidden state into a
    full softmax after every transformer block.  Left plot: probability
    trajectories of the tokens that finish in the top 6; right plot: entropy
    per layer.  Returns a matplotlib Figure.
    """
    model, enc, device = get_model()
    # NOTE(review): vocab_matrix is unused here — _get_logits_from_hidden
    # fetches it again internally.
    vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix()
    n_layers = len(model.transformer.h)

    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    layer_probs = []
    layer_entropies = []

    x = _embed(model, ids)
    freqs_cis = model.freqs_cis[:ids.size(1)]
    for block in model.transformer.h:
        x = block(x, freqs_cis)
        # Read out an intermediate distribution through the final layer norm.
        x_norm = model.transformer.ln_f(x[0, -1, :])
        probs = F.softmax(_get_logits_from_hidden(model, x_norm), dim=-1)
        layer_probs.append(probs.cpu().numpy())
        # Shannon entropy in nats; epsilon guards log(0).
        # NOTE(review): computed over the full (possibly padded) vocab, not
        # just [:REAL_VOCAB] — assumed the extra mass is negligible; confirm.
        entropy = -torch.sum(probs * torch.log(probs + 1e-9)).item()
        layer_entropies.append(entropy)

    # Pick the 6 tokens most probable at the LAST layer and trace them back.
    final_probs = layer_probs[-1][:REAL_VOCAB]
    top_idx = np.argsort(final_probs)[-6:]
    prob_flow = np.array([[p[i] for i in top_idx] for p in layer_probs])
    layers = np.arange(1, n_layers + 1)

    fig, (ax_prob, ax_ent) = plt.subplots(1, 2, figsize=(16, 5))

    colors_palette = sns.color_palette("husl", len(top_idx))
    for i, idx in enumerate(top_idx):
        # repr() keeps whitespace/newline tokens visible in the legend.
        label = repr(enc.decode([idx])).strip("'")
        ax_prob.plot(layers, prob_flow[:, i], label=label, lw=2.5, color=colors_palette[i])
    ax_prob.set_title(f"Probability Evolution: '{prompt_text}'", fontsize=11, fontweight='bold')
    ax_prob.set_xlabel("Layer")
    ax_prob.set_ylabel("Probability")
    ax_prob.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax_prob.legend(fontsize=8, loc='upper left')
    ax_prob.grid(True, alpha=0.3)

    ax_ent.plot(layers, layer_entropies, color='#FF6B35', lw=2.5, marker='o', markersize=3)
    ax_ent.set_title(f"Entropy Decay: '{prompt_text}'", fontsize=11, fontweight='bold')
    ax_ent.set_xlabel("Layer")
    ax_ent.set_ylabel("Entropy (nats)")
    ax_ent.grid(True, alpha=0.3)

    predicted = enc.decode([np.argmax(final_probs)])
    plt.suptitle(f"reFlow Layer Evolution | Prediction: '{predicted}' (p={final_probs.max():.2%})",
                 fontsize=13, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_causal_ablation(prompt_text):
    """Zero out the prediction's most-supportive signals and watch it collapse.

    Ranks signals by their contribution (activation x recipe weight) to the
    baseline prediction's logit, ablates the top 1..128 in powers of two, and
    plots the surviving probability of the original prediction.  Returns a
    matplotlib Figure.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)

    ablation_steps = [1, 2, 4, 8, 16, 32, 64, 128]

    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    # Final-position signal activations: hidden state projected onto the basis.
    sig_acts = x_norm @ W_basis.t()

    # Baseline prediction computed directly in signal space.
    logits_base = sig_acts @ W_v2s[:REAL_VOCAB].t()
    probs_base = F.softmax(logits_base, dim=-1)
    pred_id = torch.argmax(probs_base).item()
    pred_word = enc.decode([pred_id])
    pred_prob = probs_base[pred_id].item()

    # Per-signal contribution to the predicted token's logit; largest first.
    contribs = sig_acts * W_v2s[pred_id]
    sorted_sig_ids = torch.argsort(contribs, descending=True)

    steps, probs_list, new_preds = [], [], []
    for n_ablate in ablation_steps:
        if n_ablate > len(sorted_sig_ids):
            break
        ablated = sig_acts.clone()
        ablated[sorted_sig_ids[:n_ablate]] = 0.0
        logits_abl = ablated @ W_v2s[:REAL_VOCAB].t()
        probs_abl = F.softmax(logits_abl, dim=-1)
        new_pred_id = torch.argmax(probs_abl).item()
        steps.append(n_ablate)
        probs_list.append(probs_abl[pred_id].item())
        new_preds.append(enc.decode([new_pred_id]))

    # "Codebook" of the single most important signal: the tokens whose
    # recipes weight it most heavily.  (top_vals is unused; only ids decoded.)
    top_sig = sorted_sig_ids[0].item()
    col = W_v2s[:REAL_VOCAB, top_sig]
    top_vals, top_ids = torch.topk(col, 8)
    cb_words = []
    for tid in top_ids:
        try:
            cb_words.append(enc.decode([tid.item()]).strip())
        except Exception:
            cb_words.append(f"[{tid.item()}]")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Clamp to 1e-8 so the log-scale y-axis never receives zero.
    ax1.plot(steps, [max(p, 1e-8) for p in probs_list],
             'o-', color='#e74c3c', lw=2.5, markersize=6)
    ax1.axhline(y=pred_prob, color='blue', linestyle='--', alpha=0.5,
                label=f"Baseline: {pred_prob:.1%}")
    ax1.set_title(f"'{prompt_text}'\nPrediction: '{pred_word}'", fontsize=10, fontweight='bold')
    ax1.set_xlabel("# Signals Ablated")
    ax1.set_ylabel("P(original prediction)")
    ax1.set_yscale('log')
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=2))
    ax1.set_xscale('log', base=2)
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)

    # Right panel: monospace text summary instead of a chart.
    ax2.axis('off')
    summary = f"Baseline: '{pred_word}' (p={pred_prob:.2%})\n"
    summary += f"Key Signal: #{top_sig}\n"
    summary += f"Codebook: {', '.join(cb_words[:6])}\n\n"
    summary += "Ablation Results:\n" + "-"*40 + "\n"
    for s, p, nw in zip(steps, probs_list, new_preds):
        summary += f" {s:3d} signals removed -> p={p:.2%}, pred='{nw}'\n"

    ax2.text(0.05, 0.95, summary, transform=ax2.transAxes, fontsize=10,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.suptitle("reFlow Causal Ablation", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_concept_inception(prompt_text, target_word, alpha_max):
    """Steer the final hidden state toward *target_word* via its signal recipe.

    Adds ``alpha * recipe(target_word)`` to the prompt's final signal
    activations, bisects for the smallest alpha that flips the argmax to the
    target, then plots P(target) over a sweep of alpha values.

    Returns (fig, info); fig is None when the target word is not in the
    vocabulary.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)

    # Strip + leading space for the word-start BPE token, consistent with the
    # other experiments; validate instead of crashing on a bad token id.
    tids = enc.encode(" " + target_word.strip())
    if not tids or tids[0] >= REAL_VOCAB:
        return None, f"'{target_word}' not found in vocabulary."
    tid = tids[0]
    target_recipe = W_v2s[tid]

    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    base_sig = x_norm @ W_basis.t()

    # Unsteered baseline.
    logits_base = base_sig @ W_v2s[:REAL_VOCAB].t()
    probs_base = F.softmax(logits_base, dim=-1)
    orig_word = enc.decode([torch.argmax(probs_base).item()])
    orig_prob = probs_base[tid].item()

    # Bisect for the critical steering strength (assumes the flip is monotone
    # in alpha); only attempted when alpha_max itself succeeds.
    lo, hi = 0.0, float(alpha_max)
    critical_alpha = None
    probs_hi = F.softmax((base_sig + hi * target_recipe) @ W_v2s[:REAL_VOCAB].t(), dim=-1)
    if torch.argmax(probs_hi).item() == tid:
        for _ in range(20):  # 20 halvings: alpha resolved to ~alpha_max / 1e6
            mid = (lo + hi) / 2
            probs_mid = F.softmax((base_sig + mid * target_recipe) @ W_v2s[:REAL_VOCAB].t(), dim=-1)
            if torch.argmax(probs_mid).item() == tid:
                hi = mid
            else:
                lo = mid
        critical_alpha = hi

    # Sweep alpha up to 1.5x the critical point (capped at the user maximum).
    alpha_range = min(float(alpha_max), (critical_alpha or float(alpha_max)) * 1.5)
    alphas = np.linspace(0, alpha_range, 50)
    target_probs = []
    for a in alphas:
        probs = F.softmax((base_sig + a * target_recipe) @ W_v2s[:REAL_VOCAB].t(), dim=-1)
        target_probs.append(probs[tid].item())

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(alphas, target_probs, 'o-', color='#9b59b6', lw=2, markersize=3)
    # Explicit None check (not truthiness) so a near-zero flip point still draws.
    if critical_alpha is not None:
        ax.axvline(critical_alpha, color='red', linestyle='--',
                   label=f"Critical alpha={critical_alpha:.1f}")
    ax.axhline(y=orig_prob, color='gray', linestyle=':', alpha=0.5,
               label=f"Baseline P('{target_word}')={orig_prob:.1e}")
    ax.set_title(f"'{prompt_text}'\n'{orig_word}' -> '{target_word}'",
                 fontsize=11, fontweight='bold')
    ax.set_xlabel("Steering Strength (alpha)")
    ax.set_ylabel(f"P('{target_word}')")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    info = f"Original prediction: '{orig_word}'\n"
    info += f"Target: '{target_word}'\n"
    if critical_alpha is not None:
        info += f"Critical flip point: alpha = {critical_alpha:.2f}\n"
    else:
        info += f"Target not reached within alpha <= {alpha_max}\n"

    return fig, info
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_generate(prompt_text, num_samples, max_tokens, temperature, top_k, repetition_penalty):
    """Sample text continuations from the model (plain autoregressive loop).

    Gradio delivers controls as floats/strings, so everything is coerced
    first; top_k <= 0 disables top-k filtering.  Returns all samples joined
    into one separator-delimited string.
    """
    model, enc, device = get_model()

    num_samples = int(num_samples)
    max_tokens = int(max_tokens)
    top_k = int(top_k) if top_k and top_k > 0 else None
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)

    if not prompt_text.strip():
        return "Please enter a prompt."

    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    # Replicate the prompt once per sample; rows are consumed one at a time
    # in the loop below (ids[s:s+1] keeps the batch dimension).
    ids = ids.expand(num_samples, -1).contiguous()

    results = []
    for s in range(num_samples):
        x = ids[s:s+1]
        for _ in range(max_tokens):
            # Crop the context to the model's maximum block size.
            x_cond = x if x.size(1) <= model.config.block_size else x[:, -model.config.block_size:]
            logits, _ = model(x_cond)
            logits = logits[:, -1, :]

            # CTRL-style repetition penalty: shrink logits of already-seen
            # tokens (divide positives, multiply negatives).
            if repetition_penalty != 1.0:
                generated_ids = x[0].tolist()
                for token_id in set(generated_ids):
                    if logits[0, token_id] > 0:
                        logits[0, token_id] /= repetition_penalty
                    else:
                        logits[0, token_id] *= repetition_penalty

            # Temperature scaling; the floor avoids division by zero.
            logits = logits / max(temperature, 1e-8)

            # Top-k filtering: everything below the k-th logit goes to -inf.
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, idx_next), dim=1)

        text = enc.decode(x[0].tolist())
        results.append(text)

    # Join samples with a visual separator; header lines only when > 1 sample.
    separator = "\n" + "=" * 60 + "\n"
    output = ""
    for i, text in enumerate(results):
        if num_samples > 1:
            output += f"--- Sample {i+1}/{num_samples} ---\n"
        output += text + "\n"
        if i < len(results) - 1:
            output += separator
    return output
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_basis_geometry():
    """Inspect the learned signal basis: SVD spectrum and pairwise cosines.

    Compares the basis's effective rank (exp of the spectral entropy) against
    a same-shape random Gaussian matrix.  Returns (fig, stats_text).
    """
    model, enc, device = get_model()

    W_basis = model.transformer.wte.signal_basis.data.cpu().float()
    n_signals, n_embd = W_basis.shape

    U, S, Vt = torch.linalg.svd(W_basis, full_matrices=False)
    S_np = S.numpy()

    # Effective rank = exp(entropy of the normalized singular-value distribution).
    s_norm = S_np / S_np.sum()
    effective_rank = np.exp(-np.sum(s_norm * np.log(s_norm + 1e-12)))

    # Random baseline with identical shape for comparison.
    random_mat = torch.randn_like(W_basis)
    _, S_rand, _ = torch.linalg.svd(random_mat, full_matrices=False)
    S_rand_np = S_rand.numpy()
    s_rand_norm = S_rand_np / S_rand_np.sum()
    effective_rank_rand = np.exp(-np.sum(s_rand_norm * np.log(s_rand_norm + 1e-12)))

    # Cosine-similarity heatmap capped at 64 signals for readability.
    show_n = min(64, n_signals)
    W_show = W_basis[:show_n]
    W_normed = F.normalize(W_show, dim=1)
    cos_sim = (W_normed @ W_normed.t()).numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Spectra normalized by the leading singular value for comparability.
    ax1.plot(S_np / S_np[0], 'b-', lw=2, label='Learned Basis')
    ax1.plot(S_rand_np / S_rand_np[0], 'r--', lw=1.5, label='Random Gaussian')
    ax1.set_title(f"Singular Value Spectrum\n(Eff. rank: learned={effective_rank:.0f}, random={effective_rank_rand:.0f})")
    ax1.set_xlabel("Component Index")
    ax1.set_ylabel("Normalized Singular Value")
    ax1.set_yscale('log')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    im = ax2.imshow(cos_sim, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    ax2.set_title(f"Cosine Similarity (first {show_n} signals)")
    ax2.set_xlabel("Signal Index")
    ax2.set_ylabel("Signal Index")
    plt.colorbar(im, ax=ax2, fraction=0.046)

    plt.suptitle("reFlow Signal Basis Geometry", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    stats = f"Signal basis shape: ({n_signals}, {n_embd})\n"
    stats += f"Effective rank (learned): {effective_rank:.1f} / {min(n_signals, n_embd)}\n"
    stats += f"Effective rank (random): {effective_rank_rand:.1f} / {min(n_signals, n_embd)}\n"

    return fig, stats
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_recipe_neighbors(query_word, top_n):
    """Report the nearest-neighbour tokens (recipe cosine similarity) for each
    comma-separated word in *query_word*.  Returns a plain-text report."""
    model, enc, device = get_model()
    recipes = _get_vocab_signals(model)[:REAL_VOCAB]
    unit_recipes = F.normalize(recipes, dim=1)

    k = int(top_n)
    queries = [token.strip() for token in query_word.split(",") if token.strip()]
    if not queries:
        return "Please enter at least one word."

    report_parts = []
    for token in queries:
        encoded = enc.encode(" " + token)
        if not encoded or encoded[0] >= REAL_VOCAB:
            report_parts.append(f"'{token}' not found in vocabulary.\n\n")
            continue
        query_id = encoded[0]
        cosines = unit_recipes[query_id] @ unit_recipes.t()
        cosines[query_id] = -1  # never list the query token as its own neighbour
        scores, neighbor_ids = torch.topk(cosines, k)

        section = [f"Nearest neighbors for '{token}':\n" + "-" * 40 + "\n"]
        for rank, (score, nid) in enumerate(zip(scores, neighbor_ids)):
            try:
                neighbor = enc.decode([nid.item()]).strip()
            except Exception:
                neighbor = f"[{nid.item()}]"
            section.append(f" #{rank+1:2d} {neighbor:<20s} cos={score.item():.4f}\n")
        section.append("\n")
        report_parts.append("".join(section))

    return "".join(report_parts)
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def exp_task_crystallization(prompt_text, target_word, max_alpha, start_layer):
    """Find the latest layer at which steering can still flip the prediction.

    First scans for a steering strength that works when injected from
    ``start_layer``, then re-runs the intervention starting at every layer to
    locate the "crystallization" boundary past which steering fails.
    Returns (fig, info), or (None, message) when no strength works.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    n_layers = len(model.transformer.h)
    start_layer = int(start_layer)
    max_alpha = float(max_alpha)

    target_tid = enc.encode(" " + target_word.strip())[0]
    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)

    # Unsteered baseline prediction.
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    logits_base = _get_logits_from_hidden(model, x_norm)
    base_pred_id = torch.argmax(logits_base).item()
    base_pred = enc.decode([base_pred_id])

    # Re-run the forward pass, adding the steering vector (target recipe minus
    # baseline recipe, mapped back through the basis) to the final position
    # from `intercept_layer` onward.
    # NOTE(review): with intercept_layer == 0 the vector is injected both at
    # the embedding AND after every block (i + 1 >= 0 is always true), so the
    # embedding position gets an extra dose — confirm this is intentional.
    def continuous_steer(alpha, intercept_layer):
        steer_vec = W_v2s[target_tid] - W_v2s[base_pred_id]
        x = _embed(model, ids)
        if intercept_layer == 0:
            x[:, -1, :] += (alpha * steer_vec) @ W_basis

        freqs_cis = model.freqs_cis[:ids.size(1)]
        for i, block in enumerate(model.transformer.h):
            x = block(x, freqs_cis)
            if i + 1 >= intercept_layer:
                x[:, -1, :] += (alpha * steer_vec) @ W_basis

        x_norm = model.transformer.ln_f(x[0, -1, :])
        logits = _get_logits_from_hidden(model, x_norm)
        probs = F.softmax(logits, dim=-1)
        pred_id = torch.argmax(logits).item()
        return probs[target_tid].item(), enc.decode([pred_id]).strip()

    # Coarse scan for a working strength; keep a 1.2x safety margin on top.
    working_alpha = None
    for a in np.arange(2.0, max_alpha, 2.0):
        _, pred = continuous_steer(a, start_layer)
        if pred.strip() == target_word.strip():
            working_alpha = a * 1.2
            break

    if working_alpha is None:
        return None, f"Cannot steer to '{target_word}' within alpha <= {max_alpha}"

    # Sweep the intervention start layer; the first layer where steering fails
    # marks the crystallization boundary.
    layer_probs = []
    c_layer = n_layers
    for L in range(n_layers):
        p_target, pred = continuous_steer(working_alpha, L)
        layer_probs.append(p_target)
        if pred.strip() != target_word.strip() and c_layer == n_layers:
            c_layer = L

    fig, ax = plt.subplots(figsize=(10, 6))
    layers_x = np.arange(n_layers)
    ax.plot(layers_x, layer_probs, 'o-', color='#9b59b6', lw=2.5, markersize=4)
    if c_layer < n_layers:
        ax.scatter(c_layer, layer_probs[c_layer], color='red', s=150, marker='X', edgecolors='black', zorder=5)
        ax.axvline(c_layer, color='red', linestyle='--', alpha=0.5, label=f'Crystallization boundary: Layer {c_layer}')

    ax.set_title(f"Task Crystallization: '{prompt_text}' → '{target_word}'", fontsize=11, fontweight='bold')
    ax.set_xlabel("Intervention Start Layer")
    ax.set_ylabel(f"P('{target_word}')")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    info = f"Base prediction: '{base_pred}'\n"
    info += f"Target: '{target_word}'\n"
    info += f"Working alpha: {working_alpha:.1f}\n"
    info += f"Crystallization boundary: Layer {c_layer}\n"

    return fig, info
|
|
|
|