| """ |
| Integrate MixedPrecisionKVCache into Mistral/Llama generation. |
| Hooks into model forward pass to compress KV cache on the fly. |
| """ |
| import torch |
| import json |
| import os |
| import sys |
| import time |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| sys.path.append(os.path.expanduser("~/kv-hack")) |
| from kernel.quant_cache import MixedPrecisionKVCache |
|
|
# Pick the model from the CLI argument and resolve local paths.
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(f"Unknown model '{MODEL_NAME}'. Choose from: {sorted(MODEL_PATHS)}")
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
|
|
# Load the per-layer, per-head bit allocation for this model.
with open(f"{results_dir}/bit_allocation.json") as f:
    bit_alloc_raw = json.load(f)
|
|
# JSON keys are strings: map layer index -> ordered list of per-head bit widths.
bit_alloc = {
    int(layer): [bit_alloc_raw[layer][str(head)]
                 for head in range(len(bit_alloc_raw[layer]))]
    for layer in bit_alloc_raw
}
num_layers = len(bit_alloc)
print(f"Loaded bit allocation: {num_layers} layers")
|
|
all_bits = [b for heads in bit_alloc.values() for b in heads]
avg_bits = sum(all_bits) / len(all_bits)
print(f"Average bits per head: {avg_bits:.2f} (vs 16 for FP16)")
print(f"Theoretical compression: {16 / avg_bits:.2f}x")
|
|
# Load tokenizer and FP16 model onto the GPU.
print(f"\nLoading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Model loaded. Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
|
|
def run_quantized_generation(prompt: str, max_new_tokens: int = 100):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    t0 = time.time()

    with torch.no_grad():
        # Greedy generation with the standard FP16 cache as the baseline run.
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    torch.cuda.synchronize()
    elapsed = time.time() - t0
    peak_mem = torch.cuda.max_memory_allocated() / 1e9
    # Greedy decoding can stop early at EOS, so count the tokens actually produced.
    gen_tokens = out.shape[1] - inputs["input_ids"].shape[1]

    # Re-run the prefill to get the prompt's full-precision KV cache for measurement.
    with torch.no_grad():
        prefill_out = model(**inputs, use_cache=True)
        kv = prefill_out.past_key_values

    compressed_bytes = 0
    fp16_bytes = 0
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        fp16_bytes += (k.numel() + v.numel()) * 2  # 2 bytes per FP16 element

        # Quantize this layer's K/V with its per-head bit allocation.
        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()

    text = tokenizer.decode(out[0], skip_special_tokens=True)

    return {
        "text": text,
        "peak_memory_gb": round(peak_mem, 3),
        "compressed_kb": round(compressed_bytes / 1024, 1),
        "fp16_kb": round(fp16_bytes / 1024, 1),
        "compression_ratio": round(fp16_bytes / compressed_bytes, 2),
        "tokens_per_sec": round(gen_tokens / elapsed, 1),
        "time_sec": round(elapsed, 2),
    }
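
# Back-of-the-envelope check on the cache sizes (a sketch; assumes Mistral-7B
# defaults of 32 layers, 8 KV heads, head_dim 128, and ignores any quantization
# scale/zero-point overhead).
_seq_len = 1000  # hypothetical 1,000-token prompt
_fp16_mb = 32 * 2 * _seq_len * 8 * 128 * 2 / 1e6  # K and V, 2 bytes per element
print(f"Rough FP16 KV cache for {_seq_len} tokens: {_fp16_mb:.0f} MB; "
      f"~{_fp16_mb * avg_bits / 16:.0f} MB at {avg_bits:.2f} bits/head")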
|
|
|
|
prompts = [
    "The history of artificial intelligence began",
    "Explain how transformers work in deep learning:",
    "Write a Python function to sort a list:",
]
|
|
| print("\n" + "="*60) |
| print("QUANTIZED INFERENCE TEST") |
| print("="*60) |
|
|
| for prompt in prompts: |
| print(f"\nPrompt: {prompt[:50]}...") |
| result = run_quantized_generation(prompt, max_new_tokens=50) |
| print(f"Peak memory: {result['peak_memory_gb']:.2f} GB") |
| print(f"KV cache: {result['fp16_kb']:.0f} KB β {result['compressed_kb']:.0f} KB") |
| print(f"Compression: {result['compression_ratio']:.2f}x") |
| print(f"Speed: {result['tokens_per_sec']:.1f} tokens/sec") |
| print(f"Output: {result['text'][len(prompt):len(prompt)+150]}") |
|
|
| print("\nβ
Quantized inference working!") |
|
|
# Collect per-prompt results for saving.
all_results = {
    "model": MODEL_NAME,
    "timestamp": datetime.now().isoformat(),
    "avg_bits": avg_bits,
    "theoretical_compression": round(16 / avg_bits, 2),
    "prompts": [],
}
|
|
| print("\n" + "="*60) |
| print("QUANTIZED INFERENCE TEST") |
| print("="*60) |
|
|
| for prompt in prompts: |
| print(f"\nPrompt: {prompt[:50]}...") |
| result = run_quantized_generation(prompt, max_new_tokens=50) |
| print(f"Peak memory: {result['peak_memory_gb']:.2f} GB") |
| print(f"KV cache: {result['fp16_kb']:.0f} KB β {result['compressed_kb']:.0f} KB") |
| print(f"Compression: {result['compression_ratio']:.2f}x") |
| print(f"Speed: {result['tokens_per_sec']:.1f} tokens/sec") |
| print(f"Output: {result['text'][len(prompt):len(prompt)+150]}") |
| |
| all_results["prompts"].append({ |
| "prompt": prompt, |
| "compression_ratio": result["compression_ratio"], |
| "peak_memory_gb": result["peak_memory_gb"], |
| "tokens_per_sec": result["tokens_per_sec"], |
| "fp16_kb": result["fp16_kb"], |
| "compressed_kb": result["compressed_kb"], |
| }) |
|
|
# Persist the results next to the bit allocation.
out_path = f"{results_dir}/integrate_results.json"
with open(out_path, "w") as f:
    json.dump(all_results, f, indent=2)

print(f"\n✅ Results saved to {out_path}")
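
# Optional: quick cross-prompt summary (a sketch over the fields saved above).
ratios = [p["compression_ratio"] for p in all_results["prompts"]]
if ratios:
    print(f"Mean measured compression over {len(ratios)} prompts: "
          f"{sum(ratios) / len(ratios):.2f}x "
          f"(theoretical {all_results['theoretical_compression']:.2f}x)")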