# flashtrace/test_recompute.py
"""End-to-end test: verify recompute_attention mode produces identical IFR results
through the full LLMIFRAttribution pipeline, and benchmark time/memory."""
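# What is compared (internals are inferred from the llm_attr API surface, not
# verified here): recompute_attention=False presumably keeps the per-layer
# attention weights produced during the forward pass, while True is expected to
# rebuild them on demand during attribution, trading extra compute for a smaller
# resident footprint. This script only checks the observable contract: identical
# attribution matrices and the relative time/memory cost.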
import gc
import time
import tracemalloc
import torch
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerFast
from tokenizers import Tokenizer, models, pre_tokenizers
def make_model_and_tokenizer(n_layers, d_model, n_heads, n_kv_heads, max_pos):
    # Tiny randomly initialized Qwen2 model. Setting n_kv_heads < n_heads
    # exercises the grouped-query attention path; eager attention is required
    # so that per-layer attention weights can actually be returned to callers.
    config = AutoConfig.for_model(
        "qwen2",
        vocab_size=500,
        hidden_size=d_model,
        intermediate_size=d_model * 2,
        num_hidden_layers=n_layers,
        num_attention_heads=n_heads,
        num_key_value_heads=n_kv_heads,
        max_position_embeddings=max_pos,
        use_sliding_window=False,
        attn_implementation="eager",
    )
    model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")
    model.eval()
    # Whitespace WordLevel tokenizer: token "t<i>" maps directly to id i, so a
    # prompt like "t10 t20 t30" tokenizes to a predictable number of tokens.
    tok_backend = Tokenizer(models.WordLevel(
        vocab={f"t{i}": i for i in range(500)}, unk_token="t0",
    ))
    tok_backend.pre_tokenizer = pre_tokenizers.Whitespace()
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tok_backend, eos_token="t1", pad_token="t2",
    )
    # Trivial chat template that concatenates message contents verbatim.
    tokenizer.chat_template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
    return model, tokenizer, config
def run_benchmark(model, tokenizer, prompt, target, recompute, label):
    from llm_attr import LLMIFRAttribution
    gc.collect()
    # tracemalloc only traces allocations made through the Python memory manager;
    # torch tensor buffers come from torch's own C++ allocator and are largely
    # invisible to it, so peak_mem is a relative indicator, not a full RSS figure.
    tracemalloc.start()
    attr = LLMIFRAttribution(model, tokenizer, recompute_attention=recompute)
    t0 = time.perf_counter()
    result = attr.calculate_ifr_for_all_positions(prompt, target)
    elapsed = time.perf_counter() - t0
    _, peak_mem = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    print(f" {label:20s} time={elapsed:.4f}s peak_mem={peak_mem / 1024:.1f} KB "
          f"score_shape={result.attribution_matrix.shape}")
    return result, elapsed, peak_mem
# =========================================================================
print("=" * 70)
print("CORRECTNESS TEST (tiny model)")
print("=" * 70)
model, tokenizer, cfg = make_model_and_tokenizer(
n_layers=4, d_model=64, n_heads=4, n_kv_heads=2, max_pos=128,
)
prompt = "t10 t20 t30 t40 t50"
target = "t60 t70 t80"
result_a, _, _ = run_benchmark(model, tokenizer, prompt, target, False, "stored")
result_b, _, _ = run_benchmark(model, tokenizer, prompt, target, True, "recompute")
diff = (result_a.attribution_matrix - result_b.attribution_matrix).abs().max().item()
print(f" max_diff={diff:.2e} {'PASS' if diff < 1e-5 else 'FAIL'}")
# Also verify the span and multi-hop entry points agree across modes.
from llm_attr import LLMIFRAttribution
attr_a = LLMIFRAttribution(model, tokenizer, recompute_attention=False)
attr_b = LLMIFRAttribution(model, tokenizer, recompute_attention=True)
r_sa_a = attr_a.calculate_ifr_span(prompt, target)
r_sa_b = attr_b.calculate_ifr_span(prompt, target)
diff_span = (r_sa_a.attribution_matrix - r_sa_b.attribution_matrix).abs().max().item()
print(f" span max_diff={diff_span:.2e} {'PASS' if diff_span < 1e-5 else 'FAIL'}")
r_mh_a = attr_a.calculate_ifr_multi_hop(prompt, target, n_hops=2)
r_mh_b = attr_b.calculate_ifr_multi_hop(prompt, target, n_hops=2)
diff_mh = (r_mh_a.attribution_matrix - r_mh_b.attribution_matrix).abs().max().item()
print(f" multi_hop max_diff={diff_mh:.2e} {'PASS' if diff_mh < 1e-5 else 'FAIL'}")
del model, tokenizer, attr_a, attr_b
gc.collect()
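# Stored attention weights occupy O(n_layers * n_heads * seq_len^2) floats, so
# the memory gap between modes is expected to widen quadratically with sequence
# length, while recompute pays with extra forward computation instead.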
# =========================================================================
print("\n" + "=" * 70)
print("BENCHMARK: vary sequence length (L=8, d=128, H=8, KV=4)")
print("=" * 70)
for seq_len in [32, 64, 128, 256]:
    model, tokenizer, cfg = make_model_and_tokenizer(
        n_layers=8, d_model=128, n_heads=8, n_kv_heads=4, max_pos=512,
    )
    # Split the token budget between prompt and target; ids are drawn from
    # disjoint ranges (t10.. vs t200..) so the two segments share no tokens.
    prompt_len = max(4, seq_len // 2)
    target_len = seq_len - prompt_len
    prompt = " ".join(f"t{10 + i}" for i in range(prompt_len))
    target = " ".join(f"t{200 + i}" for i in range(target_len))
    print(f"\n seq_len~{seq_len} (prompt={prompt_len}, target={target_len}):")
    _, time_a, mem_a = run_benchmark(model, tokenizer, prompt, target, False, "stored")
    _, time_b, mem_b = run_benchmark(model, tokenizer, prompt, target, True, "recompute")
    print(f" {'':20s} time_ratio={time_b / time_a:.2f}x "
          f"mem_ratio={mem_b / mem_a:.2f}x mem_saved={1 - mem_b / mem_a:.0%}")
    del model, tokenizer
    gc.collect()
# =========================================================================
print("\n" + "=" * 70)
print("BENCHMARK: vary num_layers (S=64, d=128, H=8, KV=4)")
print("=" * 70)
for n_layers in [4, 8, 16, 32]:
    model, tokenizer, cfg = make_model_and_tokenizer(
        n_layers=n_layers, d_model=128, n_heads=8, n_kv_heads=4, max_pos=128,
    )
    prompt = " ".join(f"t{10 + i}" for i in range(32))
    target = " ".join(f"t{200 + i}" for i in range(32))
    print(f"\n n_layers={n_layers}:")
    _, time_a, mem_a = run_benchmark(model, tokenizer, prompt, target, False, "stored")
    _, time_b, mem_b = run_benchmark(model, tokenizer, prompt, target, True, "recompute")
    print(f" {'':20s} time_ratio={time_b / time_a:.2f}x "
          f"mem_ratio={mem_b / mem_a:.2f}x mem_saved={1 - mem_b / mem_a:.0%}")
    del model, tokenizer
    gc.collect()
print("\nAll benchmarks complete.")