Update benchmark.py

benchmark.py  +32 -138
CHANGED
@@ -4,7 +4,7 @@ Benchmarking, metrics, and proof generation for Enhanced SPG.
 Supports LongBench, NIAH, RULER, SCBench benchmarks.
 MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
 ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
-FIXED:
+FIXED: Generation errors, proper fallback handling.
 """

 import torch
@@ -144,16 +144,12 @@ class BenchmarkMetrics:
                 self.prefill_time_std = float(np.std(self.prefill_times))
                 self.prefill_time_ci = self._bootstrap_ci(self.prefill_times, config)
                 self.prefill_tokens_per_sec = config.prefill_length / self.prefill_time_mean if self.prefill_time_mean > 0 else 0.0
-            else:
-                logger.debug("No prefill time data available")

             if self.prefill_peak_memories:
                 memories_mb = [m / (1024 * 1024) for m in self.prefill_peak_memories]
                 self.prefill_peak_memory_mean_mb = float(np.mean(memories_mb))
                 self.prefill_peak_memory_std_mb = float(np.std(memories_mb))
                 self.prefill_peak_memory_ci_mb = self._bootstrap_ci(memories_mb, config)
-            else:
-                logger.debug("No prefill memory data available")

             if self.decode_times:
                 self.decode_time_per_token_mean_ms = float(np.mean(self.decode_times) * 1000)
@@ -162,8 +158,6 @@
                 self.decode_tokens_per_sec = 1.0 / np.mean(self.decode_times) if self.decode_times else 0.0
                 self.decode_time_p50_ms = float(np.percentile(self.decode_times, 50) * 1000)
                 self.decode_time_p95_ms = float(np.percentile(self.decode_times, 95) * 1000)
-            else:
-                logger.debug("No decode time data available")

             # Calculate end-to-end throughput
             if self.prefill_time_mean > 0 and self.decode_time_per_token_mean_ms > 0:
@@ -174,37 +168,23 @@

             if self.decode_peak_memories:
                 self.decode_peak_memory_mean_mb = float(np.mean(self.decode_peak_memories) / (1024 * 1024))
-            else:
-                logger.debug("No decode memory data available")

             if self.prefill_perplexities:
                 self.prefill_perplexity_mean = float(np.mean(self.prefill_perplexities))
                 self.prefill_perplexity_std = float(np.std(self.prefill_perplexities))
                 self.prefill_perplexity_ci = self._bootstrap_ci(self.prefill_perplexities, config)
-                logger.info(f"Calculated prefill perplexity: mean={self.prefill_perplexity_mean:.2f}, "
-                            f"std={self.prefill_perplexity_std:.2f}, samples={len(self.prefill_perplexities)}")
-            else:
-                logger.warning("No prefill perplexity data available")

             if self.generation_perplexities:
                 self.generation_perplexity_mean = float(np.mean(self.generation_perplexities))
                 self.generation_perplexity_std = float(np.std(self.generation_perplexities))
                 self.generation_perplexity_ci = self._bootstrap_ci(self.generation_perplexities, config)
-                logger.info(f"Calculated generation perplexity: mean={self.generation_perplexity_mean:.2f}, "
-                            f"std={self.generation_perplexity_std:.2f}, samples={len(self.generation_perplexities)}")
-            else:
-                logger.warning("No generation perplexity data available")

             if self.compression_ratios:
                 self.compression_ratio_mean = float(np.mean(self.compression_ratios))
                 self.compression_ratio_std = float(np.std(self.compression_ratios))
-            else:
-                logger.debug("No compression ratio data available")

             if self.kv_cache_memory_samples_mb:
                 self.kv_cache_memory_mb = float(np.mean(self.kv_cache_memory_samples_mb))
-            else:
-                logger.debug("No KV cache memory data available")

         except Exception as e:
             logger.error(f"Error calculating statistics: {e}")
@@ -213,7 +193,6 @@ class BenchmarkMetrics:
     def _bootstrap_ci(self, data: List[float], config: CompressionConfig) -> Tuple[float, float]:
         """Calculate bootstrap confidence interval with reproducible RNG."""
         if not data or len(data) < 2:
-            logger.warning("Insufficient data for confidence interval calculation")
             return (0.0, 0.0)

         try:
@@ -240,11 +219,9 @@ class BenchmarkMetrics:

 def safe_tokenize(tokenizer, text, max_length=512):
     """Safe tokenization with proper padding and truncation."""
-    # Ensure pad_token is set
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

-    # Tokenize with explicit parameters
     inputs = tokenizer(
         text,
         return_tensors="pt",
@@ -255,12 +232,10 @@ def safe_tokenize(tokenizer, text, max_length=512):
         add_special_tokens=True
     )

-    # Validate outputs
     if inputs.input_ids.shape[1] == 0:
         raise ValueError("Tokenization produced empty sequence")

     if inputs.input_ids.shape[1] > max_length:
-        logger.warning(f"Sequence length {inputs.input_ids.shape[1]} exceeds max {max_length}")
         inputs.input_ids = inputs.input_ids[:, :max_length]
         inputs.attention_mask = inputs.attention_mask[:, :max_length]

@@ -269,41 +244,35 @@ def safe_tokenize(tokenizer, text, max_length=512):

 def validate_model_inputs(model, input_ids, attention_mask):
     """Validate inputs are compatible with model."""
-    # Check sequence length against model's max position embeddings
     if hasattr(model.config, 'max_position_embeddings'):
         max_pos = model.config.max_position_embeddings
         if input_ids.shape[1] > max_pos:
-            logger.warning(f"Input length {input_ids.shape[1]} exceeds model max {max_pos}")
             input_ids = input_ids[:, :max_pos]
             attention_mask = attention_mask[:, :max_pos]

-    # For GPT-2, check n_positions
     if hasattr(model.config, 'n_positions'):
         n_pos = model.config.n_positions
         if input_ids.shape[1] > n_pos:
-            logger.warning(f"Input length {input_ids.shape[1]} exceeds GPT-2 positions {n_pos}")
             input_ids = input_ids[:, :n_pos]
             attention_mask = attention_mask[:, :n_pos]

-    # Ensure input_ids are within vocabulary range
     vocab_size = model.config.vocab_size
     if input_ids.max() >= vocab_size:
+        input_ids = input_ids.clamp(0, vocab_size - 1)
+
+    if input_ids.min() < 0:
         input_ids = input_ids.clamp(0, vocab_size - 1)

     return input_ids, attention_mask


 def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20):
-    """Safe generation with proper error handling."""
+    """Safe generation with proper error handling - returns generated text."""
     try:
-        # Validate inputs
         input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)

-        # Set generation config
         gen_config = {
             "max_new_tokens": max_new_tokens,
-            "temperature": 0.7,
             "do_sample": False,
             "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
             "eos_token_id": tokenizer.eos_token_id,
@@ -311,20 +280,21 @@ def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=N
             "use_cache": True
         }

-        # Add past_key_values if available
         if past_key_values is not None:
             gen_config["past_key_values"] = past_key_values

-        # Generate with error handling
         with torch.no_grad():
             output = model.generate(input_ids, **gen_config)

+        # Decode only the generated part
+        generated_ids = output[:, input_ids.shape[1]:]
+        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        return generated_text

     except Exception as e:
         logger.error(f"Generation failed: {e}")
-        # Return
-        return
+        # Return empty string on failure
+        return ""


 def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
@@ -336,21 +306,17 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
     """
     device = input_ids.device

-    # Validate inputs first
     input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)

-    # Clear GPU cache if requested
     if torch.cuda.is_available() and measure_memory:
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()

-    # Measure prefill time
     if torch.cuda.is_available():
         torch.cuda.synchronize()
     start_time = time.perf_counter()

-    # Prefill phase with error handling
     try:
         with torch.inference_mode():
             outputs = model(
@@ -363,7 +329,6 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
             logits = outputs.logits
     except Exception as e:
         logger.error(f"Prefill failed: {e}")
-        # Return minimal valid result
         return {
             'past_key_values': None,
             'prefill_time': 0,
@@ -380,22 +345,18 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,

     prefill_time = time.perf_counter() - start_time

-    # Measure peak memory
     prefill_peak_mem = 0
     if torch.cuda.is_available() and measure_memory:
         prefill_peak_mem = _peak_mem_bytes_all_gpus()

-    # Calculate prefill perplexity safely
     prefill_loss = None
     if logits is not None and input_ids.shape[1] > 1:
         try:
-            # Ensure we have valid shapes
             seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
             if seq_len > 0:
                 shift_logits = logits[:, :seq_len, :].contiguous()
                 shift_labels = input_ids[:, 1:seq_len+1].contiguous()

-                # Calculate loss with ignore_index for padding
                 loss = F.cross_entropy(
                     shift_logits.view(-1, shift_logits.size(-1)),
                     shift_labels.view(-1),
@@ -406,30 +367,28 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
         except Exception as e:
             logger.warning(f"Could not calculate prefill loss: {e}")

-    # Compression phase - same as WikiText
     original_cache_size = 0
     compressed_cache_size = 0
     compression_ratio = 1.0

     if past_key_values:
         try:
+            if hasattr(past_key_values, 'to_legacy_cache'):
+                kv_tuple = past_key_values.to_legacy_cache()
+            else:
+                kv_tuple = past_key_values

-            # Calculate original size
             for layer_idx, (keys, values) in enumerate(kv_tuple):
                 if keys is not None and values is not None:
                     original_cache_size += keys.nelement() * keys.element_size()
                     original_cache_size += values.nelement() * values.element_size()

-                    # Apply compression if enabled
                     if config.compression_type != CompressionType.NONE and cache_manager is not None:
                         try:
                             cache_manager.compress_and_store(layer_idx, keys, values)
                         except Exception as e:
                             logger.error(f"Compression failed for layer {layer_idx}: {e}")

-            # Reconstruct compressed cache
             if config.compression_type != CompressionType.NONE and cache_manager is not None:
                 reconstructed_kv = []
                 for layer_idx in range(len(kv_tuple)):
@@ -438,20 +397,16 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
                         if dec_keys is not None and dec_values is not None:
                             reconstructed_kv.append((dec_keys, dec_values))
                         else:
-                            # Use original if decompression fails
-                            logger.warning(f"Decompression returned None for layer {layer_idx}, using original")
                             reconstructed_kv.append(kv_tuple[layer_idx])
                     except Exception as e:
                         logger.error(f"Decompression failed for layer {layer_idx}: {e}")
                         reconstructed_kv.append(kv_tuple[layer_idx])

-                # Convert back to DynamicCache format
                 if hasattr(DynamicCache, 'from_legacy_cache'):
                     past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
                 else:
                     past_key_values = tuple(reconstructed_kv)

-                # Measure compressed size
                 try:
                     compressed_cache_size = cache_manager.get_memory_footprint()
                 except:
@@ -459,8 +414,8 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
             else:
                 compressed_cache_size = original_cache_size

+            if compressed_cache_size > 0:
+                compression_ratio = original_cache_size / compressed_cache_size

         except Exception as e:
             logger.error(f"Cache processing failed: {e}")
@@ -481,7 +436,6 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,

 def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
     """Create Needle-in-a-Haystack test context - NO HARDCODING."""
-    # Generate haystack text
     haystack_template = "The quick brown fox jumps over the lazy dog. " * 20
     haystack_chunks = []

@@ -490,7 +444,6 @@ def create_niah_haystack(context_length: int, needle: str, depth_percent: float)

     haystack = " ".join(haystack_chunks)[:context_length - len(needle) - 10]

-    # Insert needle at specified depth
     insertion_point = int(len(haystack) * depth_percent / 100)
     haystack_with_needle = (
         haystack[:insertion_point] +
@@ -511,25 +464,19 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op

     prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:"

-    # Use safe tokenization
     inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
     input_ids = inputs.input_ids.to(model.device)
     attention_mask = inputs.attention_mask.to(model.device)

-    # Apply SAME compression pipeline as WikiText
     compression_result = apply_compression_pipeline(
         model, tokenizer, input_ids, attention_mask, cache_manager, config
     )

-    # Generate with compressed cache using safe generation
     gen_start = time.perf_counter()
+    generated_text = safe_generate(model, tokenizer, input_ids, attention_mask,
+                                   compression_result['past_key_values'], max_new_tokens=20)
     gen_time = time.perf_counter() - gen_start

-    generated_text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
-    # Check if needle was retrieved
     accuracy = 1.0 if config.niah_needle.split()[-1] in generated_text else 0.0

     logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
@@ -547,10 +494,8 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op

 def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
     """Evaluate RULER with SAME compression pipeline as WikiText."""
-    seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024)  # Cap at GPT-2 limit
+    seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024)

-    # Create a retrieval task with multiple facts
     facts = []
     for i in range(10):
         facts.append(f"Fact {i}: The capital of Country{i} is City{i}.")
@@ -565,20 +510,15 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: O
     input_ids = inputs.input_ids.to(model.device)
     attention_mask = inputs.attention_mask.to(model.device)

-    # Apply SAME compression pipeline as WikiText
     compression_result = apply_compression_pipeline(
         model, tokenizer, input_ids, attention_mask, cache_manager, config
     )

-    # Generate with compressed cache
     gen_start = time.perf_counter()
+    generated = safe_generate(model, tokenizer, input_ids, attention_mask,
+                              compression_result['past_key_values'], max_new_tokens=10)
     gen_time = time.perf_counter() - gen_start

-    generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
-    # Check exact match
     expected = f"City{query_idx}"
     exact_match = 1.0 if expected in generated else 0.0

@@ -597,7 +537,6 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: O

 def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
     """Evaluate SCBench with SAME compression pipeline as WikiText."""
-    # Create multi-turn conversation
     conversation = []
     facts = {}

@@ -612,7 +551,6 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager:
         conversation.append(f"User: {user_msg}")
         conversation.append(f"Assistant: {assistant_msg}")

-    # Query a random fact
     query_key = random.choice(list(facts.keys()))
     conversation.append(f"User: What is {query_key}?")

@@ -622,20 +560,15 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager:
     input_ids = inputs.input_ids.to(model.device)
     attention_mask = inputs.attention_mask.to(model.device)

-    # Apply SAME compression pipeline as WikiText
     compression_result = apply_compression_pipeline(
         model, tokenizer, input_ids, attention_mask, cache_manager, config
     )

-    # Generate with compressed cache
     gen_start = time.perf_counter()
+    generated = safe_generate(model, tokenizer, input_ids, attention_mask,
+                              compression_result['past_key_values'], max_new_tokens=20)
     gen_time = time.perf_counter() - gen_start

-    generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
-    # Check if correct value is recalled
     expected_value = facts[query_key]
     accuracy = 1.0 if expected_value in generated else 0.0

@@ -658,7 +591,6 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
     try:
         dataset = load_dataset("THUDM/LongBench", task, split="test")

-        # Sample evaluation examples
         n_samples = min(config.eval_samples, len(dataset))
         samples = dataset.select(range(n_samples))

@@ -682,21 +614,16 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
             input_ids = inputs.input_ids.to(model.device)
             attention_mask = inputs.attention_mask.to(model.device)

-            # Apply SAME compression pipeline as WikiText
             compression_result = apply_compression_pipeline(
                 model, tokenizer, input_ids, attention_mask, cache_manager, config,
-                measure_memory=False
+                measure_memory=False
             )

-            # Generate with compressed cache
             gen_start = time.perf_counter()
+            generated = safe_generate(model, tokenizer, input_ids, attention_mask,
+                                      compression_result['past_key_values'], max_new_tokens=50)
             gen_time = time.perf_counter() - gen_start

-            generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
-            # Simple accuracy metric
             score = 1.0 if str(answer).lower() in generated.lower() else 0.0
             scores.append(score)
             compression_ratios.append(compression_result['compression_ratio'])
@@ -705,7 +632,6 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
             gen_times.append(gen_time)

         avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0
-        logger.info(f"LongBench {task} avg compression: {avg_compression:.1f}x")

         return {
             'accuracy': float(np.mean(scores)),
@@ -733,15 +659,11 @@ def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.float16 if device == "cuda" else torch.float32

-    # FAIL FAST if CUDA required but unavailable
     if config.fail_on_cpu_fallback and device == "cpu":
         raise RuntimeError("CUDA required but unavailable (fail_on_cpu_fallback=True)")

     logger.info(f"Loading model: {model_name}")

-    # Check if model requires authentication
-    model_info = SUPPORTED_MODELS.get(config.model_key, {})
     tokenizer = AutoTokenizer.from_pretrained(
         model_name,
         trust_remote_code=True
@@ -750,7 +672,6 @@ def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

-    # Model loading with Flash Attention support
     model_kwargs = {
         "torch_dtype": dtype,
         "device_map": "auto" if device == "cuda" else None,
@@ -758,20 +679,16 @@ def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
         "trust_remote_code": True
     }

-    # Try Flash Attention if requested and available
     if config.use_flash_attention and device == "cuda":
         try:
-            # First try to load with Flash Attention
             model_kwargs["attn_implementation"] = "flash_attention_2"
             model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
             logger.info("Successfully loaded with Flash Attention 2")
         except Exception as e:
-            logger.warning(f"Flash Attention not available, using standard attention: {e}")
+            logger.warning(f"Flash Attention not available: {e}")
             model_kwargs.pop("attn_implementation", None)
             model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
     else:
-        # Load without Flash Attention
         model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

     model.eval()
@@ -784,7 +701,6 @@ def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]
     logger.info(f"Loading samples for benchmark: {config.benchmark_type}")

     if config.benchmark_type == "wikitext":
-        # Original WikiText loading
         texts = []
         min_tokens = config.prefill_length + config.generation_length

@@ -823,7 +739,6 @@ def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]
             raise

     elif config.benchmark_type == "longbench":
-        # Load LongBench dataset
         texts = []
         if config.benchmark_subset:
             try:
@@ -839,7 +754,6 @@ def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]
             raise

     elif config.benchmark_type in ["niah", "ruler", "scbench"]:
-        # These benchmarks generate synthetic data
         texts = ["Synthetic benchmark data"] * config.eval_samples

     else:
@@ -858,7 +772,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
     logger.info(f"Benchmark type: {config.benchmark_type}")
     logger.info(f"Config hash: {config.get_hash()}")

-    # Enable synchronous CUDA for debugging
     if torch.cuda.is_available():
         os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

@@ -876,7 +789,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
         logger.error(f"Failed to detect model layers: {e}")
         raise

-    # Warmup
     device = model.device
     with torch.inference_mode():
         dummy = torch.randint(0, tokenizer.vocab_size, (1, min(config.prefill_length, 128)), device=device)
@@ -899,13 +811,10 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t

         metrics = BenchmarkMetrics()

-        # Run benchmark-specific evaluation with UNIFIED compression
         if config.benchmark_type == "niah":
-            # NIAH evaluation with unified compression
             for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
                 config.niah_depth_percent = depth
                 for idx in range(min(config.eval_samples, 10)):
-                    # Create cache manager for compression types
                     if config.compression_type != CompressionType.NONE:
                         cache_manager = QuantizedKVCache(config)
                         cache_manager.n_layers = n_layers
@@ -918,12 +827,11 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                     metrics.compression_ratios.append(result['compression_ratio'])
                     metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
                     metrics.prefill_times.append(result['prefill_time'])
-                    metrics.decode_times.append(result['generation_time'] / 20)
+                    metrics.decode_times.append(result['generation_time'] / 20)

                     if result['prefill_peak_mem'] > 0:
                         metrics.prefill_peak_memories.append(result['prefill_peak_mem'])

-                    # Record per-sample data
                     per_sample_records.append({
                         'benchmark': 'niah',
                         'depth_percent': depth,
@@ -935,7 +843,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                     })

         elif config.benchmark_type == "ruler":
-            # RULER evaluation with unified compression
             for idx in range(config.eval_samples):
                 if config.compression_type != CompressionType.NONE:
                     cache_manager = QuantizedKVCache(config)
@@ -949,7 +856,7 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                 metrics.compression_ratios.append(result['compression_ratio'])
                 metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
                 metrics.prefill_times.append(result['prefill_time'])
-                metrics.decode_times.append(result['generation_time'] / 10)
+                metrics.decode_times.append(result['generation_time'] / 10)

                 if result['prefill_peak_mem'] > 0:
                     metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
@@ -964,7 +871,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                 })

         elif config.benchmark_type == "scbench":
-            # SCBench evaluation with unified compression
             for idx in range(config.eval_samples):
                 if config.compression_type != CompressionType.NONE:
                     cache_manager = QuantizedKVCache(config)
@@ -978,7 +884,7 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                 metrics.compression_ratios.append(result['compression_ratio'])
                 metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
                 metrics.prefill_times.append(result['prefill_time'])
-                metrics.decode_times.append(result['generation_time'] / 20)
+                metrics.decode_times.append(result['generation_time'] / 20)

                 if result['prefill_peak_mem'] > 0:
                     metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
@@ -993,7 +899,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                 })

         elif config.benchmark_type == "longbench":
-            # LongBench evaluation with unified compression
             if config.benchmark_subset:
                 if config.compression_type != CompressionType.NONE:
                     cache_manager = QuantizedKVCache(config)
@@ -1010,7 +915,7 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                 metrics.prefill_times.append(result['prefill_time'])

                 if result['generation_time'] > 0:
-                    metrics.decode_times.append(result['generation_time'] / 50)
+                    metrics.decode_times.append(result['generation_time'] / 50)

                 per_sample_records.append({
                     'benchmark': 'longbench',
@@ -1022,7 +927,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                 })

         else:
-            # Standard WikiText perplexity evaluation with existing compression
             for idx in range(config.eval_samples):
                 logger.info(f"Sample {idx+1}/{config.eval_samples}")

@@ -1036,12 +940,10 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                 else:
                     cache_manager = None

-                # Use safe tokenization
                 inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024))
                 input_ids = inputs.input_ids.to(device)
                 attention_mask = inputs.attention_mask.to(device)

-                # Apply unified compression pipeline
                 compression_result = apply_compression_pipeline(
                     model, tokenizer, input_ids, attention_mask, cache_manager, config
                 )
@@ -1057,7 +959,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
                     prefill_perplexity = np.exp(compression_result['prefill_loss'])
                     metrics.prefill_perplexities.append(min(prefill_perplexity, 1000))

-                # Generation phase with timing
                 generated_ids = input_ids.clone()
                 decode_times = []
                 generation_losses = []
@@ -1110,7 +1011,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
         metrics.calculate_statistics(config)
         all_metrics.append(metrics)

-    # Aggregate results across seeds
     final_metrics = BenchmarkMetrics()
     for m in all_metrics:
         final_metrics.prefill_times.extend(m.prefill_times)
@@ -1128,7 +1028,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t

     final_metrics.calculate_statistics(config)

-    # Summary
     end_time = datetime.now().isoformat()
     summary = {
         'compression_type': config.compression_type.value,
@@ -1142,7 +1041,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
         'end_time': end_time
     }

-    # Add benchmark-specific metrics
     if config.benchmark_type == "niah" and final_metrics.niah_retrieval_accuracy:
         summary['niah_accuracy'] = float(np.mean(final_metrics.niah_retrieval_accuracy))
     elif config.benchmark_type == "ruler" and final_metrics.ruler_exact_match:
@@ -1155,7 +1053,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
         summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
         summary['generation_perplexity'] = final_metrics.generation_perplexity_mean

-    # Always add timing and memory metrics
     summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
     summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
     summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
@@ -1253,7 +1150,6 @@ def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: Pr
     recomputed = {}
     failures = []

-    # Verify based on benchmark type
     if config.benchmark_type == "niah":
         if "niah_accuracy" in summary:
             recomputed["niah_accuracy"] = mean_of("accuracy")
@@ -1267,13 +1163,11 @@ def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: Pr
         if "longbench_accuracy" in summary:
             recomputed["longbench_accuracy"] = mean_of("accuracy")
     elif config.benchmark_type == "wikitext":
-        # WikiText benchmark metrics
         if "prefill_perplexity" in summary:
             recomputed["prefill_perplexity"] = mean_of("prefill_perplexity")
         if "generation_perplexity" in summary:
             recomputed["generation_perplexity"] = mean_of("generation_perplexity")

-    # Always verify compression metrics
     recomputed["compression_ratio"] = mean_of("compression_ratio")
     recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")

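Note (illustrative, not part of the committed file): the revised `safe_generate` now returns only the newly generated text, and an empty string on failure, which is what the benchmark evaluators score against with substring checks. A minimal standalone sketch of that contract, assuming a stock `gpt2` checkpoint purely for illustration:

```python
# Sketch of the safe_generate contract: decode only the continuation,
# return "" on any generation failure so accuracy checks stay well defined.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_text(model, tokenizer, prompt, max_new_tokens=20):
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(prompt, return_tensors="pt")
    try:
        with torch.no_grad():
            output = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Mirror output[:, input_ids.shape[1]:] from the patched helper.
        return tokenizer.decode(output[0][inputs.input_ids.shape[1]:],
                                skip_special_tokens=True)
    except Exception:
        return ""  # a failed generation is simply scored as a miss


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    print(repr(generate_text(lm, tok, "The secret password is alpha. The secret password is")))
```

Returning "" rather than None keeps the `in generated` checks in evaluate_niah, evaluate_ruler, and evaluate_scbench valid even when generation raises.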
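For reference, the cache bookkeeping that `apply_compression_pipeline` now guards (computing `compression_ratio = original / compressed` only when the compressed size is positive) reduces to byte counting over the legacy KV tuple. A self-contained sketch with made-up tensor shapes and a stand-in value for `cache_manager.get_memory_footprint()`:

```python
# Sketch of the KV-cache size accounting; shapes and the 4x stand-in
# compression factor are arbitrary assumptions for illustration.
import torch


def cache_nbytes(kv_tuple):
    """Total bytes held by a legacy-format cache: tuple of (keys, values) per layer."""
    total = 0
    for keys, values in kv_tuple:
        if keys is not None and values is not None:
            total += keys.nelement() * keys.element_size()
            total += values.nelement() * values.element_size()
    return total


# Two dummy fp16 layers (batch=1, heads=12, seq=128, head_dim=64).
kv = tuple(
    (torch.zeros(1, 12, 128, 64, dtype=torch.float16),
     torch.zeros(1, 12, 128, 64, dtype=torch.float16))
    for _ in range(2)
)
original = cache_nbytes(kv)
compressed = original // 4  # stand-in for the cache manager's reported footprint
ratio = original / compressed if compressed > 0 else 1.0
print(f"original={original} B, compressed={compressed} B, ratio={ratio:.1f}x")
```

The `compressed > 0` guard is the same check the patch adds before dividing, so a failed or empty compression step degrades to a reported ratio of 1.0 instead of raising.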