| """ |
Long context visualization — both models.
| """ |
| import json |
| import matplotlib.pyplot as plt |
| import matplotlib.ticker as ticker |
| import os |
|
|
def load_long(model_name, root="~/kv-hack/results"):
    """Load the long-context benchmark results for one model.

    Args:
        model_name: Name of the per-model results directory,
            e.g. ``"mistral-7b"``.
        root: Directory that contains one sub-directory per model. A leading
            ``~`` is expanded. Defaults to the location this script writes
            its figures next to.

    Returns:
        The parsed contents of
        ``<root>/<model_name>/long_context_results.json``.

    Raises:
        FileNotFoundError: If the results file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    path = os.path.expanduser(
        f"{root}/{model_name}/long_context_results.json"
    )
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding (e.g. cp1252 on Windows), which can mis-decode non-ASCII text.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
|
# Every figure below is written into this directory; create it up front.
_FIGURES_DIR = os.path.expanduser("~/kv-hack/figures")
os.makedirs(_FIGURES_DIR, exist_ok=True)

# Benchmark results for the two models (loaded in this order).
mistral, llama = load_long("mistral-7b"), load_long("llama-3-8b")

# Shared palette used by all figures: baselines first, then one colour per model.
C_FP16, C_UNIFORM = "#ef4444", "#f97316"
C_MISTRAL, C_LLAMA = "#22c55e", "#3b82f6"
|
|
| |
# ---------------------------------------------------------------------------
# Figure 1: KV-cache memory vs context length, one panel per model.
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(1, 2, figsize=(18, 7))

# Each entry: (axis, results, model colour, panel title, context length at
# which FP16 ran out of memory -- or None when it never did).
for ax, data, color, title, oom_ctx in [
    (axes[0], mistral, C_MISTRAL, "Mistral-7B", None),
    (axes[1], llama, C_LLAMA, "Llama-3-8B", 32768),
]:
    # Keep only records that carry the mixed-precision measurement; the same
    # records supply the FP16 and uniform 8-bit series so all three align.
    valid = [r for r in data["results"] if "mixed_precision_mb" in r]
    if not valid:
        # Nothing to plot; without this guard every ctx[-1] below would raise.
        ax.set_title(f"{title}\n(no mixed-precision results)",
                     fontsize=13, fontweight='bold')
        continue
    ctx = [r["context_len"] for r in valid]
    fp16 = [r["fp16_mb"] for r in valid]
    uni8 = [r["uniform8_mb"] for r in valid]
    ours = [r["mixed_precision_mb"] for r in valid]

    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9,
            label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9,
            label="Uniform 8-bit")
    ax.plot(ctx, ours, '^-', color=color, linewidth=3, markersize=9,
            label="Per-Head Mixed (Ours)")
    # Shade the gap between the FP16 and mixed-precision curves.
    ax.fill_between(ctx, fp16, ours, alpha=0.08, color=color)

    if oom_ctx:
        # Dashed line at the longest context that was actually measured.
        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
        ax.text(ctx[-1] * 0.92, max(fp16) * 0.85,
                "FP16\nOOM ✗", color=C_FP16,
                fontweight='bold', fontsize=10, ha='right')

        # Linearly extrapolate our footprint from the last measured length to
        # oom_ctx. The previous hard-coded "* 2" was only correct when the
        # last measured length happened to be exactly oom_ctx / 2.
        ours_at_oom = ours[-1] * oom_ctx / ctx[-1]
        # Place the label over the previous point; fall back to the last
        # point when only one context length was measured.
        label_x = ctx[-2] if len(ctx) > 1 else ctx[-1]
        ax.annotate(f"Ours at {oom_ctx // 1024}K:\n~{ours_at_oom:.0f}MB ✓",
                    xy=(ctx[-1], ours[-1]),
                    xytext=(label_x, ours[-1] + 200),
                    color=color, fontweight='bold', fontsize=9,
                    arrowprops=dict(arrowstyle='->', color=color))

    # Label the final FP16 and mixed-precision values in GB.
    ax.annotate(f"{fp16[-1] / 1024:.1f} GB",
                xy=(ctx[-1], fp16[-1]),
                xytext=(-40, 10), textcoords='offset points',
                color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{ours[-1] / 1024:.1f} GB",
                xy=(ctx[-1], ours[-1]),
                xytext=(-40, -20), textcoords='offset points',
                color=color, fontweight='bold', fontsize=9)

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length",
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    ax.set_xticklabels([f"{c // 1024}K" if c >= 1024 else str(c) for c in ctx])

plt.suptitle("Per-Head Mixed-Precision KV Cache — Long Context Benchmark\n"
             "Llama-3-8B FP16 OOMs at 32K. Our method fits.",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_both.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/long_context_both.png")
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Figure 2: both models' KV-cache memory against the A100 memory left after
# loading each model's weights.
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(12, 6))

# Build every series of a model from ONE filtered record list so x and y always
# have the same length and refer to the same records. (Mistral's FP16 series
# used to be filtered on "fp16_mb" while its x values were filtered on
# "mixed_precision_mb", which misaligns or crashes when the filters differ.)
m_valid = [r for r in mistral["results"] if "mixed_precision_mb" in r]
m_fp16 = [r["fp16_mb"] for r in m_valid]
m_ours = [r["mixed_precision_mb"] for r in m_valid]
m_ctx = [r["context_len"] for r in m_valid]

l_valid = [r for r in llama["results"] if "mixed_precision_mb" in r]
l_fp16 = [r["fp16_mb"] for r in l_valid]
l_ours = [r["mixed_precision_mb"] for r in l_valid]
l_ctx = [r["context_len"] for r in l_valid]

# Memory budget in MB. The per-model values are assumed weight footprints
# (NOTE(review): presumably FP16 weights -- confirm); headroom = total - weights.
mistral_model_mem = 14.5 * 1024
llama_model_mem = 16.0 * 1024
a100_total = 40 * 1024

# ".1f" so that 25.5 GB of headroom is not shown rounded up to "26GB".
ax.axhline(y=a100_total - mistral_model_mem,
           color='gray', linestyle='--', alpha=0.7, linewidth=2,
           label=f"A100 headroom (Mistral): "
                 f"{(a100_total - mistral_model_mem) / 1024:.1f}GB")
ax.axhline(y=a100_total - llama_model_mem,
           color='gray', linestyle=':', alpha=0.7, linewidth=2,
           label=f"A100 headroom (Llama): "
                 f"{(a100_total - llama_model_mem) / 1024:.1f}GB")

ax.plot(m_ctx, m_fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=7,
        label="FP16 (Mistral)")
ax.plot(m_ctx, m_ours, '^-', color=C_MISTRAL, linewidth=2.5, markersize=7,
        label="Ours (Mistral)")
ax.plot(l_ctx, l_fp16, 'o--', color="#f87171", linewidth=2.5, markersize=7,
        label="FP16 (Llama)")
ax.plot(l_ctx, l_ours, '^--', color=C_LLAMA, linewidth=2.5, markersize=7,
        label="Ours (Llama)")

if l_ctx:
    # Point at Llama's last measured FP16 value instead of a hard-coded 16384;
    # the text sits at ~73% of that x, matching the old 12000-vs-16384 layout.
    ax.annotate("Llama FP16\nOOM here ✗",
                xy=(l_ctx[-1], l_fp16[-1]),
                xytext=(l_ctx[-1] * 0.73, l_fp16[-1] + 400),
                color=C_FP16, fontweight='bold', fontsize=10,
                arrowprops=dict(arrowstyle='->', color=C_FP16))

ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
ax.set_title("KV Cache Memory vs GPU Headroom\n"
             "Our method keeps you under the limit longer",
             fontsize=14, fontweight='bold')
ax.legend(fontsize=10, loc='upper left')
ax.grid(True, alpha=0.3)
# Tick every measured length of either model and derive the labels from the
# ticks; a fixed 7-label list fails whenever the tick count differs.
oom_ticks = sorted(set(m_ctx) | set(l_ctx))
ax.set_xticks(oom_ticks)
ax.set_xticklabels([f"{c // 1024}K" if c >= 1024 else str(c) for c in oom_ticks])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/oom_story.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/oom_story.png")
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Figure 3: prefill latency vs context length for both models.
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(10, 5))

# Pair each prefill time with the context length of the SAME record. The old
# code plotted "prefill_ms" rows against context lengths taken from the
# "mixed_precision_mb" rows, which misaligns (or fails on length mismatch)
# whenever the two filters keep different records.
m_prefill_rows = [r for r in mistral["results"] if "prefill_ms" in r]
l_prefill_rows = [r for r in llama["results"] if "prefill_ms" in r]
m_prefill_ctx = [r["context_len"] for r in m_prefill_rows]
m_prefill = [r["prefill_ms"] for r in m_prefill_rows]
l_prefill_ctx = [r["context_len"] for r in l_prefill_rows]
l_prefill = [r["prefill_ms"] for r in l_prefill_rows]

ax.plot(m_prefill_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
        markersize=8, label="Mistral-7B")
ax.plot(l_prefill_ctx, l_prefill, 's-', color=C_LLAMA, linewidth=2.5,
        markersize=8, label="Llama-3-8B")

# Value labels: Mistral above its points, Llama below, so they don't collide.
for x, y in zip(m_prefill_ctx, m_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, 10), textcoords='offset points',
                ha='center', fontsize=8, color=C_MISTRAL)
for x, y in zip(l_prefill_ctx, l_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, -18), textcoords='offset points',
                ha='center', fontsize=8, color=C_LLAMA)

ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
ax.set_title("Prefill Latency vs Context Length\nBoth Models",
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
# Ticks/labels derived from the data rather than a fixed 7-entry label list.
prefill_ticks = sorted(set(m_prefill_ctx) | set(l_prefill_ctx))
ax.set_xticks(prefill_ticks)
ax.set_xticklabels([f"{c // 1024}K" if c >= 1024 else str(c)
                    for c in prefill_ticks])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/prefill_latency_both.png")

# Release every figure created by this script.
plt.close('all')
print("\n🎉 All long context graphs saved!")