kfoughali committed
Commit 00b4f4f · verified · 1 Parent(s): 11a45c9

Update compression.py

Files changed (1)
  1. compression.py +0 -1052
compression.py CHANGED
@@ -1,1052 +0,0 @@
-"""
-Enhanced SPG compression algorithms with RocketKV-style 450x compression.
-NO ESTIMATIONS - only measured values. FAIL FAST on errors.
-"""
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-from typing import Tuple, Optional, Dict, Any, List
-from dataclasses import replace
-import logging
-
-from config import (
-    CompressionConfig, EnhancedSPGConfig, CompressionType,
-    ResearchConstants
-)
-
-logger = logging.getLogger(__name__)
-
-class EnhancedSlidingPrecisionGradient:
-    """
-    Research-grade Enhanced SPG with RocketKV-style 450x compression capability.
-    NO ESTIMATIONS OR HARDCODED VALUES - all parameters from validated config.
-    """
-
-    def __init__(self, config: EnhancedSPGConfig):
-        self.config = config
-        self.constants = ResearchConstants()
-        self.layer_decay_rates: Optional[List[float]] = None
-        self.compression_stats: List[Dict[str, Any]] = []
-
-        # Progressive compression state
-        self.current_compression_ratio = config.initial_compression_ratio if config.enable_progressive else None
-        self.progressive_step = 0
-        self.quality_history: List[float] = []
-
-        # Adaptive state
-        self.adaptive_enabled = config.enable_adaptive
-        self.decay_adjustment_rate = config.decay_adjustment_rate
-        self.target_perplexity_delta = config.target_perplexity_delta
-
-        # RocketKV-style adaptive decomposition
-        self.use_adaptive_decomposition = config.use_adaptive_decomposition
-        self.use_hybrid_sparse_attention = config.use_hybrid_sparse_attention
-        self.target_compression_ratio = config.target_compression_ratio
-
-        logger.info(f"Enhanced SPG initialized with {config.magnitude_threshold_mode} magnitude thresholds")
-        if self.use_hybrid_sparse_attention:
-            logger.info("RocketKV-style Hybrid Sparse Attention enabled")
-
-    def initialize_layer_decay_rates(self, n_layers: int) -> None:
-        """Initialize per-layer decay rates with validation."""
-        if not self.constants.MIN_LAYERS <= n_layers <= self.constants.MAX_LAYERS:
-            logger.warning(f"n_layers {n_layers} outside typical range [{self.constants.MIN_LAYERS}, {self.constants.MAX_LAYERS}]")
-
-        # Both branches of the former per_layer_decay check assigned the same value,
-        # so initialize every layer at the base rate; adaptive updates differentiate them later
-        self.layer_decay_rates = [self.config.base_decay_rate] * n_layers
-
-        self.n_layers = n_layers
-        logger.info(f"Initialized decay rates for {n_layers} layers")
-
-    def update_decay_rate(self, layer_idx: int, quality_metric: float, target_quality: float) -> None:
-        """Update decay rate for adaptive SPG with proper validation."""
-        if not self.adaptive_enabled or self.layer_decay_rates is None:
-            return
-
-        if not 0 <= layer_idx < len(self.layer_decay_rates):
-            logger.error(f"Invalid layer_idx {layer_idx}, valid range: [0, {len(self.layer_decay_rates)})")
-            return
-
-        # Validate and clamp inputs
-        quality_metric = max(0.1, min(1000.0, float(quality_metric)))
-        target_quality = max(0.1, min(1000.0, float(target_quality)))
-
-        # Compute adjustment
-        quality_delta = quality_metric - target_quality
-
-        if quality_delta > 0:  # Quality worse than target
-            adjustment = -self.decay_adjustment_rate * (quality_delta / target_quality)
-        else:  # Quality better than target
-            adjustment = self.decay_adjustment_rate * (abs(quality_delta) / target_quality)
-
-        # Apply with bounds
-        old_rate = self.layer_decay_rates[layer_idx]
-        new_rate = max(0.8, min(0.99, old_rate + adjustment))
-        self.layer_decay_rates[layer_idx] = new_rate
-
-        logger.debug(f"Adaptive SPG Layer {layer_idx}: quality={quality_metric:.3f}, "
-                     f"target={target_quality:.3f}, decay_rate: {old_rate:.3f} → {new_rate:.3f}")
-
-    def compute_magnitude_importance(self, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
-        """
-        Compute importance scores based on magnitude statistics.
-        This is an EXPLICIT magnitude-based proxy, not an estimation.
-        """
-        try:
-            # Compute L2 norm across head dimension for each token
-            k_norms = keys.norm(dim=-1).mean(dim=1).mean(dim=0)  # [seq_len]
-            v_norms = values.norm(dim=-1).mean(dim=1).mean(dim=0)  # [seq_len]
-
-            # Combine key and value magnitudes (explicit formula)
-            importance_scores = (k_norms + v_norms) / 2.0
-
-            # Normalize to [0, 1] range for consistent thresholding
-            score_min = importance_scores.min()
-            score_max = importance_scores.max()
-
-            if score_max > score_min:
-                importance_scores = (importance_scores - score_min) / (score_max - score_min)
-            else:
-                importance_scores = torch.ones_like(importance_scores)
-
-            logger.debug(f"Computed magnitude importance: min={score_min:.6f}, max={score_max:.6f}")
-            return importance_scores
-
-        except Exception as e:
-            logger.error(f"Error computing magnitude importance: {e}")
-            raise
-
-    def estimate_attention_sparsity(self, keys: torch.Tensor, values: torch.Tensor) -> float:
-        """Estimate attention pattern sparsity for adaptive decomposition. FAIL FAST on error."""
-        try:
-            # Compute approximate attention patterns using key-key similarity
-            k_norm = F.normalize(keys.float(), p=2, dim=-1)
-            attention_approx = torch.matmul(k_norm, k_norm.transpose(-2, -1))
-
-            # Measure sparsity as fraction of near-zero attention weights
-            # Use configurable threshold from constants
-            threshold = self.constants.ATTENTION_SPARSITY_THRESHOLD
-            sparse_fraction = (attention_approx.abs() < threshold).float().mean().item()
-
-            return sparse_fraction
-
-        except Exception as e:
-            # FAIL FAST - NO FALLBACK VALUES
-            logger.error(f"Failed to estimate attention sparsity: {e}")
-            raise RuntimeError(f"Cannot measure attention sparsity: {e}")
-
-    def adaptive_stage_split(self, target_ratio: float, seq_len: int, sparsity: float) -> Tuple[float, float]:
-        """RocketKV-style adaptive compression decomposition with explicit parameters."""
-        # Use explicit formulas from research constants
-        if sparsity > self.constants.SPARSITY_HIGH_THRESHOLD:
-            stage1_power = self.constants.SPARSE_STAGE1_POWER
-        elif sparsity > self.constants.SPARSITY_MEDIUM_THRESHOLD:
-            stage1_power = self.constants.BALANCED_STAGE1_POWER
-        else:
-            stage1_power = self.constants.DENSE_STAGE1_POWER
-
-        stage1_ratio = target_ratio ** stage1_power
-        stage2_ratio = target_ratio / stage1_ratio
-
-        # Bounds checking with explicit limits from config
-        stage1_ratio = max(self.config.stage_compression_min, min(self.config.stage_compression_max, stage1_ratio))
-        stage2_ratio = max(self.config.stage_compression_min, min(self.config.stage_compression_max, stage2_ratio))
-
-        logger.debug(f"Adaptive split: sparsity={sparsity:.3f}, stage1={stage1_ratio:.1f}x, stage2={stage2_ratio:.1f}x")
-        return stage1_ratio, stage2_ratio
-
-    def snapkv_plus_plus(self, keys: torch.Tensor, values: torch.Tensor,
-                         compression_ratio: float) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
-        """SnapKV++ with GQA support and adaptive pooling - no hardcoded values."""
-        batch_size, n_heads, seq_len, head_dim = keys.shape
-
-        # Adaptive kernel size based on sequence length (from config)
-        kernel_size = self.config.get_adaptive_kernel_size(seq_len)
-
-        # Compute importance scores with adaptive pooling
-        key_norms = keys.norm(dim=-1)  # [batch, heads, seq]
-        value_norms = values.norm(dim=-1)
-        combined_importance = (key_norms + value_norms) / 2.0
-
-        # Multi-head aggregation with adaptive pooling
-        if kernel_size > 1:
-            # Apply 1D pooling along sequence dimension
-            pooled_importance = F.avg_pool1d(
-                combined_importance.mean(dim=1).unsqueeze(1),  # [batch, 1, seq]
-                kernel_size=kernel_size,
-                stride=1,
-                padding=kernel_size // 2
-            ).squeeze(1)  # [batch, seq]
-            # Ensure pooled output matches original sequence length
-            if pooled_importance.shape[-1] != seq_len:
-                pooled_importance = pooled_importance[:, :seq_len]
-        else:
-            pooled_importance = combined_importance.mean(dim=1)
-
-        # Aggregate across batch
-        final_importance = pooled_importance.mean(dim=0)  # [seq]
-
-        # Ensure importance tensor matches sequence length
-        if final_importance.shape[0] != seq_len:
-            final_importance = final_importance[:seq_len]
-
-        # Preserve sink and recent tokens (guard recent_window == 0, where
-        # mask[-0:] would mark the entire sequence as preserved)
-        preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
-        preserve_mask[:min(self.config.sink_tokens, seq_len)] = True
-        recent = min(self.config.recent_window, seq_len)
-        if recent > 0:
-            preserve_mask[-recent:] = True
-
-        # Top-k selection for remaining tokens
-        n_keep = max(self.config.sink_tokens + self.config.recent_window,
-                     int(seq_len / compression_ratio))
-        n_keep = min(n_keep, seq_len)  # Ensure we don't exceed sequence length
-        remaining_slots = n_keep - preserve_mask.sum().item()
-
-        if remaining_slots > 0:
-            masked_importance = final_importance.clone()
-            masked_importance[preserve_mask] = -float('inf')
-
-            available_indices = (~preserve_mask).nonzero(as_tuple=True)[0]
-            if len(available_indices) > 0:
-                k = min(remaining_slots, len(available_indices))
-                if k > 0:
-                    _, relative_top_indices = torch.topk(masked_importance[available_indices], k)
-                    absolute_top_indices = available_indices[relative_top_indices]
-                    preserve_mask[absolute_top_indices] = True
-
-        # Extract retained tokens with bounds checking
-        retained_indices = torch.where(preserve_mask)[0]
-        retained_indices = retained_indices[retained_indices < seq_len]  # Safety check
-
-        keys_compressed = keys[:, :, retained_indices, :]
-        values_compressed = values[:, :, retained_indices, :]
-
-        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else float('inf')
-        logger.debug(f"SnapKV++: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")
-
-        return keys_compressed, values_compressed, retained_indices.tolist()
-
-    def hybrid_sparse_attention(self, keys: torch.Tensor, values: torch.Tensor,
-                                head_budget: int, seq_budget: int) -> Dict[str, Any]:
-        """RocketKV-style Hybrid Sparse Attention for Stage 2 - no hardcoded values."""
-        batch_size, n_heads, seq_len, head_dim = keys.shape
-
-        # 1. Head-wise importance scoring
-        head_importance = (
-            keys.float().pow(2).sum(dim=(-1, -2)).sum(dim=0) +  # Sum over batch, seq, hidden
-            values.float().pow(2).sum(dim=(-1, -2)).sum(dim=0)
-        )  # [n_heads]
-
-        # Select top heads
-        actual_head_budget = min(head_budget, n_heads)
-        _, top_head_indices = torch.topk(head_importance, actual_head_budget)
-
-        compressed_data = {
-            'keys': {},
-            'values': {},
-            'metadata': {
-                'head_selection': top_head_indices.tolist(),
-                'original_shape': keys.shape,
-                'compression_type': 'hybrid_sparse_attention'
-            }
-        }
-
-        # 2. Sequence-wise top-k selection per selected head
-        for head_idx in top_head_indices:
-            head_keys = keys[:, head_idx:head_idx+1, :, :]  # Keep head dimension
-            head_values = values[:, head_idx:head_idx+1, :, :]
-
-            # Compute sequence importance for this head
-            seq_importance = (
-                head_keys.norm(dim=-1).squeeze(1).mean(dim=0) +  # [seq]
-                head_values.norm(dim=-1).squeeze(1).mean(dim=0)
-            ) / 2.0
-
-            # Apply position-based boost (from research constants)
-            position_boost = torch.ones_like(seq_importance)
-            position_boost[:self.config.sink_tokens] *= self.constants.POSITION_BOOST_SINK
-            if self.config.recent_window > 0:  # guard: [-0:] would boost every position
-                position_boost[-self.config.recent_window:] *= self.constants.POSITION_BOOST_RECENT
-            boosted_importance = seq_importance * position_boost
-
-            # Select top tokens for this head
-            actual_seq_budget = min(seq_budget, seq_len)
-            _, top_token_indices = torch.topk(boosted_importance, actual_seq_budget)
-
-            # Store compressed data
-            head_key = f'head_{head_idx.item()}'
-            compressed_data['keys'][head_key] = {
-                'data': head_keys[:, :, top_token_indices, :].clone(),
-                'indices': top_token_indices.tolist()
-            }
-            compressed_data['values'][head_key] = {
-                'data': head_values[:, :, top_token_indices, :].clone(),
-                'indices': top_token_indices.tolist()
-            }
-
-        return compressed_data
-
-    def stage1_permanent_eviction(self, keys: torch.Tensor, values: torch.Tensor,
-                                  layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
-        """
-        Stage 1: RocketKV-style permanent eviction with SnapKV++ or magnitude-guided approach.
-        """
-        batch_size, n_heads, seq_len, head_dim = keys.shape
-
-        if self.use_adaptive_decomposition:
-            # Use adaptive compression split
-            sparsity = self.estimate_attention_sparsity(keys, values)  # May raise if fails
-            stage1_ratio, _ = self.adaptive_stage_split(self.target_compression_ratio, seq_len, sparsity)
-        else:
-            stage1_ratio = self.config.stage1_compression_ratio
-
-        # Choose compression method based on configuration
-        if self.config.use_snapkv_plus_plus:
-            return self.snapkv_plus_plus(keys, values, stage1_ratio)
-        else:
-            # Original magnitude-guided approach
-            return self._magnitude_guided_stage1(keys, values, layer_idx, stage1_ratio)
-
-    def _magnitude_guided_stage1(self, keys: torch.Tensor, values: torch.Tensor,
-                                 layer_idx: int, compression_ratio: float) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
-        """Original magnitude-guided Stage 1 eviction with explicit parameters."""
-        batch_size, n_heads, seq_len, head_dim = keys.shape
-
-        # Calculate retention based on compression ratio
-        retention_ratio = 1.0 / compression_ratio
-        min_retain = self.config.sink_tokens + self.config.recent_window
-        n_retain = max(min_retain, int(seq_len * retention_ratio))
-
-        # Apply layer-specific constraints (from research constants)
-        layer_position = layer_idx / max(getattr(self, 'n_layers', 12) - 1, 1)
-        if layer_position <= 0.5:  # Early layers
-            max_retain = int(seq_len * self.constants.EARLY_LAYER_MAX_RETENTION)
-        else:  # Late layers
-            max_retain = int(seq_len * self.constants.LATE_LAYER_MAX_RETENTION)
-
-        n_retain = min(n_retain, max_retain)
-
-        # Compute magnitude-based importance
-        importance_scores = self.compute_magnitude_importance(keys, values)
-
-        # Quality preservation: boost recent tokens (explicit formula from config)
-        recent_boost = torch.zeros_like(importance_scores)
-        if self.config.recent_window > 0:
-            recent_boost[-self.config.recent_window:] = importance_scores.max() * self.config.recent_boost_factor
-        importance_scores = importance_scores + recent_boost
-
-        # Initialize preservation mask (guard recent_window == 0: mask[-0:] selects everything)
-        preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
-        preserve_mask[:self.config.sink_tokens] = True
-        if self.config.recent_window > 0:
-            preserve_mask[-self.config.recent_window:] = True
-
-        # Select additional tokens based on importance
-        remaining_slots = n_retain - preserve_mask.sum().item()
-        if remaining_slots > 0:
-            masked_importance = importance_scores.clone()
-            masked_importance[preserve_mask] = -float('inf')
-
-            # Use configured threshold (not hardcoded)
-            magnitude_threshold = torch.quantile(
-                importance_scores.float(),
-                self.config.get_magnitude_threshold()
-            )
-
-            below_threshold = masked_importance < magnitude_threshold
-            masked_importance[below_threshold] = -float('inf')
-
-            available = (masked_importance > -float('inf')).sum().item()
-            k = min(remaining_slots, available)
-            if k > 0:
-                _, top_indices = torch.topk(masked_importance, k)
-                preserve_mask[top_indices] = True
-
-        # Extract retained tokens
-        retained_indices = torch.where(preserve_mask)[0]
-        keys_stage1 = keys[:, :, retained_indices, :]
-        values_stage1 = values[:, :, retained_indices, :]
-
-        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else float('inf')
-        logger.debug(f"Stage 1 Layer {layer_idx}: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")
-
-        return keys_stage1, values_stage1, retained_indices.tolist()
-
-    def stage2_multi_dimensional_compression(self, keys: torch.Tensor, values: torch.Tensor,
-                                             layer_idx: int, retained_indices: List[int]) -> Dict[str, Any]:
-        """
-        Stage 2: RocketKV-style Hybrid Sparse Attention compression.
-        Uses dynamic top-k selection with head and sequence reductions.
-        """
-        batch_size, n_heads, seq_len, head_dim = keys.shape
-
-        if self.use_hybrid_sparse_attention:
-            # RocketKV-style compression with adaptive budgets
-            sparsity = self.estimate_attention_sparsity(keys, values)  # May raise if fails
-
-            if self.use_adaptive_decomposition:
-                _, stage2_ratio = self.adaptive_stage_split(
-                    self.target_compression_ratio, seq_len, sparsity
-                )
-            else:
-                stage2_ratio = self.config.stage2_compression_ratio
-
-            # Dynamic budgets based on compression target (from config)
-            head_retention_ratio = self.config.get_head_retention_ratio()
-            head_budget = max(1, int(n_heads * head_retention_ratio))
-            seq_budget = max(self.config.min_tokens_for_stability, int(seq_len / stage2_ratio))
-
-            # Use hybrid sparse attention
-            compressed_data = self.hybrid_sparse_attention(keys, values, head_budget, seq_budget)
-
-            # Add metadata
-            compressed_data['metadata'].update({
-                'stage1_retained_indices': retained_indices,
-                'original_shape_after_stage1': keys.shape,
-                'original_dtype': keys.dtype,
-                'layer_idx': layer_idx,
-                'sparsity_estimate': sparsity,
-                'stage2_compression_ratio': stage2_ratio,
-                'head_budget': head_budget,
-                'seq_budget': seq_budget,
-                'head_retention_ratio': head_retention_ratio
-            })
-
-            return compressed_data
-
-        # Fallback to original multi-dimensional compression
-        return self._original_stage2_compression(keys, values, layer_idx, retained_indices)
-
-    def _original_stage2_compression(self, keys: torch.Tensor, values: torch.Tensor,
-                                     layer_idx: int, retained_indices: List[int]) -> Dict[str, Any]:
-        """Original Stage 2 implementation for comparison."""
-        batch_size, n_heads, seq_len, head_dim = keys.shape
-
-        # Compute importance for remaining tokens
-        importance_scores = self.compute_magnitude_importance(keys, values)
-
-        # Combine with position-based decay (explicit formula)
-        decay_rate = self.layer_decay_rates[layer_idx] if self.layer_decay_rates else self.config.base_decay_rate
-        position_scores = torch.pow(
-            decay_rate,
-            torch.arange(seq_len, device=keys.device).float() / self.config.decay_normalization
-        )
-
-        combined_importance = importance_scores * position_scores
-
-        compressed_data = {
-            'keys': {},
-            'values': {},
-            'metadata': {
-                'stage1_retained_indices': retained_indices,
-                'importance_scores': combined_importance,
-                'original_shape_after_stage1': keys.shape,
-                'original_dtype': keys.dtype,
-                'layer_idx': layer_idx,
-                'magnitude_threshold_mode': self.config.magnitude_threshold_mode,
-                'compression_type': 'original_multi_dimensional'
-            }
-        }
-
-        # Head dimension compression with explicit parameters
-        if self.config.enable_head_compression:
-            n_important_heads = max(1, int(n_heads * self.config.head_compression_ratio))
-
-            # UPDATED: Always reserve top head_fp16_reserve heads at full precision
-            n_reserved_heads = min(getattr(self.config, 'head_fp16_reserve', 2), n_heads)
-            n_important_heads = max(n_reserved_heads, n_important_heads)
-
-            # Compute head importance (explicit calculation)
-            head_importance = (
-                keys.float().pow(2).sum(dim=(-1, -2)).sum(dim=0) +
-                values.float().pow(2).sum(dim=(-1, -2)).sum(dim=0)
-            )
-
-            _, important_head_indices = torch.topk(head_importance, n_important_heads)
-            other_head_indices = torch.tensor(
-                [h for h in range(n_heads) if h not in important_head_indices.tolist()],
-                device=keys.device, dtype=torch.long
-            )
-
-            # Store important heads at full precision
-            compressed_data['keys']['heads_fp16'] = {
-                'data': keys[:, important_head_indices, :, :].clone(),
-                'indices': important_head_indices.tolist()
-            }
-            compressed_data['values']['heads_fp16'] = {
-                'data': values[:, important_head_indices, :, :].clone(),
-                'indices': important_head_indices.tolist()
-            }
-
-            if other_head_indices.numel() == 0:
-                return compressed_data
-
-            seq_keys = keys[:, other_head_indices, :, :]
-            seq_values = values[:, other_head_indices, :, :]
-        else:
-            seq_keys = keys
-            seq_values = values
-
-        # Sequence dimension compression with explicit ratios
-        levels = self.config.precision_levels
-
-        # Explicit top-K selection for FP16
-        keep_fp16 = max(0, int(seq_len * self.config.sequence_compression_ratio))
-        top_fp16 = torch.topk(combined_importance, k=keep_fp16).indices if keep_fp16 > 0 else torch.empty(0, dtype=torch.long, device=keys.device)
-        is_fp16 = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
-        if keep_fp16 > 0:
-            is_fp16[top_fp16] = True
-
-        # Vectorized token binning. Note: torch.bucketize requires ascending
-        # boundaries, so sort ascending and map the bucket count back to a level
-        # index (levels are ordered by descending threshold, matching the
-        # exclusive binning in _fallback_to_original_spg)
-        thresh = torch.tensor([pl.threshold for pl in levels], device=keys.device)
-        thresh_asc, _ = torch.sort(thresh)
-        level_ids = len(levels) - torch.bucketize(combined_importance, thresh_asc, right=True)
-
-        # Assign tokens to precision levels
-        for i in range(seq_len):
-            if is_fp16[i]:
-                precision_key = 'seq_fp16'
-            else:
-                level_idx = min(level_ids[i].item(), len(levels) - 1)
-                level = levels[level_idx]
-
-                if level.bits is not None:
-                    precision_key = f'seq_{level.bits}bit'
-                else:
-                    precision_key = f'seq_{level.name}'
-
-            if precision_key not in compressed_data['keys']:
-                compressed_data['keys'][precision_key] = {
-                    'indices': [], 'data': None, 'scale': None, 'zero': None
-                }
-                compressed_data['values'][precision_key] = {
-                    'indices': [], 'data': None, 'scale': None, 'zero': None
-                }
-
-            compressed_data['keys'][precision_key]['indices'].append(i)
-            compressed_data['values'][precision_key]['indices'].append(i)
-
-        # Store data with aggressive precision (FP16 for most important tokens)
-        keys_to_delete = []
-        for precision_key in list(compressed_data['keys'].keys()):
-            if not precision_key.startswith('seq_'):
-                continue
-
-            indices = compressed_data['keys'][precision_key]['indices']
-            if not indices:
-                keys_to_delete.append(precision_key)
-                continue
-
-            if precision_key == 'seq_discard':
-                keys_to_delete.append(precision_key)
-                continue
-
-            idx_tensor = torch.tensor(indices, device=keys.device, dtype=torch.long)
-            k_slice = seq_keys.index_select(2, idx_tensor)
-            v_slice = seq_values.index_select(2, idx_tensor)
-
-            # Store with aggressive precision - only FP16 for ultra-selective tokens
-            compressed_data['keys'][precision_key]['data'] = k_slice.clone()
-            compressed_data['values'][precision_key]['data'] = v_slice.clone()
-
-        # Clean up empty keys
-        for pk in keys_to_delete:
-            compressed_data['keys'].pop(pk, None)
-            compressed_data['values'].pop(pk, None)
-
-        return compressed_data
-
-    def compress_with_enhanced_gradient(self, keys: torch.Tensor, values: torch.Tensor,
-                                        layer_idx: int, current_position: int) -> Dict[str, Any]:
-        """
-        Main compression function with explicit two-stage approach.
-        """
-        if not self.config.enable_two_stage:
-            return self._fallback_to_original_spg(keys, values, layer_idx, current_position)
-
-        try:
-            # Record original shape
-            orig_shape_full = keys.shape
-
-            # Stage 1: Permanent eviction
-            keys_stage1, values_stage1, retained_indices = self.stage1_permanent_eviction(
-                keys, values, layer_idx
-            )
-
-            # Stage 2: Multi-dimensional compression
-            compressed_data = self.stage2_multi_dimensional_compression(
-                keys_stage1, values_stage1, layer_idx, retained_indices
-            )
-
-            # Add metadata
-            compressed_data['metadata']['original_full_shape'] = orig_shape_full
-
-            # Progressive compression
-            if self.config.enable_progressive:
-                compressed_data = self._apply_progressive_compression(compressed_data, layer_idx)
-
-            return compressed_data
-
-        except Exception as e:
-            logger.error(f"Error in enhanced compression for layer {layer_idx}: {e}")
-            raise
-
-    def _fallback_to_original_spg(self, keys: torch.Tensor, values: torch.Tensor,
-                                  layer_idx: int, current_position: Optional[int]) -> Dict[str, Any]:
-        """Fallback to original SPG implementation with actual data storage."""
-        batch_size, n_heads, seq_len, head_dim = keys.shape
-
-        # Original position-based precision computation
-        device = keys.device
-
-        decay_rate = self.layer_decay_rates[layer_idx] if self.layer_decay_rates else self.config.base_decay_rate
-
-        positions = torch.arange(seq_len, device=device)
-        if current_position is None or not isinstance(current_position, (int, float)):
-            current_position = seq_len
-        current_position = int(current_position)
-        distances = torch.tensor(current_position, device=device, dtype=positions.dtype) - positions
-
-        precision_scores = torch.pow(decay_rate, distances.float() / self.config.decay_normalization)
-        precision_scores[:self.config.sink_tokens] = 1.0
-
-        recent_mask = distances < self.config.recent_window
-        precision_scores[recent_mask] = torch.maximum(
-            precision_scores[recent_mask],
-            torch.tensor(self.config.recent_min_precision, device=device)
-        )
-
-        # Apply precision levels with actual data storage
-        compressed_data = {
-            'keys': {},
-            'values': {},
-            'metadata': {
-                'precision_scores': precision_scores,
-                'original_shape': keys.shape,
-                'original_dtype': keys.dtype,
-                'layer_idx': layer_idx,
-                'compression_type': 'original_spg'
-            }
-        }
-
-        # Exclusive binning for precision levels
-        levels = self.config.precision_levels
-        for i, score in enumerate(precision_scores):
-            for j, level in enumerate(levels):
-                lo = level.threshold
-                hi = levels[j-1].threshold if j > 0 else float('inf')
-
-                if lo <= score < hi:
-                    if level.bits is not None:
-                        precision_key = f'{level.bits}bit'
-                    else:
-                        precision_key = level.name
-
-                    if precision_key not in compressed_data['keys']:
-                        compressed_data['keys'][precision_key] = {
-                            'indices': [], 'data': None, 'scale': None, 'zero': None
-                        }
-                        compressed_data['values'][precision_key] = {
-                            'indices': [], 'data': None, 'scale': None, 'zero': None
-                        }
-
-                    compressed_data['keys'][precision_key]['indices'].append(i)
-                    compressed_data['values'][precision_key]['indices'].append(i)
-                    break
-
-        # Process data
-        keys_to_delete = []
-        for precision_key in list(compressed_data['keys'].keys()):
-            indices = compressed_data['keys'][precision_key]['indices']
-            if not indices:
-                keys_to_delete.append(precision_key)
-                continue
-
-            if precision_key == 'discard':
-                keys_to_delete.append(precision_key)
-                continue
-
-            level_indices = torch.tensor(indices, device=device, dtype=torch.long)
-            k_slice = keys.index_select(2, level_indices)
-            v_slice = values.index_select(2, level_indices)
-
-            # Store with FP16 precision (simplified for original SPG)
-            compressed_data['keys'][precision_key]['data'] = k_slice.clone()
-            compressed_data['values'][precision_key]['data'] = v_slice.clone()
-
-        # Clean up empty keys
-        for pk in keys_to_delete:
-            compressed_data['keys'].pop(pk, None)
-            compressed_data['values'].pop(pk, None)
-
-        return compressed_data
-
-    def _apply_progressive_compression(self, compressed_data: Dict, layer_idx: int) -> Dict:
-        """Apply progressive compression with relative quality change detection."""
-        if len(self.quality_history) >= self.constants.PROGRESSIVE_QUALITY_WINDOW:
-            recent = float(np.mean(self.quality_history[-self.constants.PROGRESSIVE_RECENT_WINDOW:]))
-            prev = float(np.mean(self.quality_history[-self.constants.PROGRESSIVE_QUALITY_WINDOW:-self.constants.PROGRESSIVE_RECENT_WINDOW]))
-            rel_delta = (recent - prev) / max(prev, 1e-9)
-
-            if rel_delta <= self.config.quality_threshold:
-                old_ratio = self.current_compression_ratio or self.config.initial_compression_ratio
-                new_ratio = min(old_ratio * self.config.progression_factor, self.config.max_compression_ratio)
-
-                if new_ratio > old_ratio:
-                    self.current_compression_ratio = new_ratio
-                    compression_factor = new_ratio / old_ratio
-
-                    # Tighten compression ratios (use configurable minimum from config)
-                    self.config.head_compression_ratio = max(self.config.progressive_min_ratio,
-                                                             self.config.head_compression_ratio / compression_factor)
-                    self.config.sequence_compression_ratio = max(self.config.progressive_min_ratio,
-                                                                 self.config.sequence_compression_ratio / compression_factor)
-
-                    self.progressive_step += 1
-
-                    logger.info(f"Progressive step {self.progressive_step}: rel_delta={rel_delta:.4f}, new_ratio={new_ratio:.1f}x")
-
-        compressed_data['metadata']['progressive_compression_ratio'] = self.current_compression_ratio
-        compressed_data['metadata']['progressive_step'] = self.progressive_step
-
-        return compressed_data
-
-    def decompress(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Decompress enhanced SPG compressed data."""
-        metadata = compressed_data['metadata']
-
-        if metadata.get('compression_type') == 'original_spg':
-            return self._decompress_original_spg(compressed_data)
-
-        return self._decompress_enhanced_spg(compressed_data)
-
-    def _decompress_enhanced_spg(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Decompress enhanced multi-stage compressed data with HSA support."""
-        metadata = compressed_data['metadata']
-
-        # Get device from first available tensor
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        for storage_type in ['keys', 'values']:
-            for key, data in compressed_data[storage_type].items():
-                if isinstance(data, dict) and 'data' in data and isinstance(data['data'], torch.Tensor):
-                    device = data['data'].device
-                    break
-            if device != torch.device('cuda' if torch.cuda.is_available() else 'cpu'):
-                break
-
-        # Handle hybrid sparse attention format
-        if metadata.get('compression_type') == 'hybrid_sparse_attention':
-            return self._decompress_hybrid_sparse_attention(compressed_data)
-
-        # Original enhanced SPG decompression
-        original_shape = metadata['original_shape_after_stage1']
-        original_dtype = metadata['original_dtype']
-
-        keys_full = torch.zeros(original_shape, dtype=original_dtype, device=device)
-        values_full = torch.zeros(original_shape, dtype=original_dtype, device=device)
-
-        # Decompress head dimension data first
-        if 'heads_fp16' in compressed_data['keys']:
-            head_indices = compressed_data['keys']['heads_fp16']['indices']
-            head_idx_tensor = torch.tensor(head_indices, device=device, dtype=torch.long)
-            keys_full[:, head_idx_tensor, :, :] = compressed_data['keys']['heads_fp16']['data']
-            values_full[:, head_idx_tensor, :, :] = compressed_data['values']['heads_fp16']['data']
-
-            if self.config.enable_head_compression:
-                n_heads = original_shape[1]
-                other_head_indices = torch.tensor([h for h in range(n_heads) if h not in head_indices],
-                                                  device=device, dtype=torch.long)
-            else:
-                other_head_indices = head_idx_tensor
-        else:
-            other_head_indices = torch.arange(original_shape[1], device=device, dtype=torch.long)
-
-        # Decompress sequence dimension data
-        for precision_key in [k for k in compressed_data['keys'].keys() if k.startswith('seq_')]:
-            if 'data' not in compressed_data['keys'][precision_key]:
-                continue
-
-            indices = compressed_data['keys'][precision_key]['indices']
-            idx_tensor = torch.tensor(indices, device=device, dtype=torch.long)
-
-            # All data stored as FP16 in this simplified version.
-            # Note: advanced indexing with a tensor returns a copy, so calling
-            # index_copy_ directly on keys_full[:, other_head_indices, ...] would
-            # silently write into a temporary; scatter into the copy, then assign back.
-            k_sub = keys_full[:, other_head_indices, :, :]
-            v_sub = values_full[:, other_head_indices, :, :]
-            k_sub.index_copy_(2, idx_tensor, compressed_data['keys'][precision_key]['data'])
-            v_sub.index_copy_(2, idx_tensor, compressed_data['values'][precision_key]['data'])
-            keys_full[:, other_head_indices, :, :] = k_sub
-            values_full[:, other_head_indices, :, :] = v_sub
-
-        return keys_full, values_full
-
-    def _decompress_hybrid_sparse_attention(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Decompress RocketKV-style hybrid sparse attention data."""
-        metadata = compressed_data['metadata']
-        original_shape = metadata['original_shape']
-
-        # Get device from first available tensor
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        for head_key in compressed_data['keys'].keys():
-            if head_key.startswith('head_'):
-                device = compressed_data['keys'][head_key]['data'].device
-                break
-
-        # Initialize full tensors
-        keys_full = torch.zeros(original_shape, dtype=torch.float16, device=device)
-        values_full = torch.zeros(original_shape, dtype=torch.float16, device=device)
-
-        # Reconstruct selected heads with their tokens
-        for head_key in compressed_data['keys'].keys():
-            if not head_key.startswith('head_'):
-                continue
-
-            head_idx = int(head_key.split('_')[1])
-            head_data_k = compressed_data['keys'][head_key]
-            head_data_v = compressed_data['values'][head_key]
-
-            token_indices = head_data_k['indices']
-
-            # Place data in the correct head and token positions
-            keys_full[:, head_idx:head_idx+1, token_indices, :] = head_data_k['data']
-            values_full[:, head_idx:head_idx+1, token_indices, :] = head_data_v['data']
-
-        return keys_full, values_full
-
-    def _decompress_original_spg(self, compressed_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Decompress original SPG data."""
-        metadata = compressed_data['metadata']
-        original_shape = metadata['original_shape']
-        original_dtype = metadata['original_dtype']
-        device = metadata['precision_scores'].device
-
-        keys_full = torch.zeros(original_shape, dtype=original_dtype, device=device)
-        values_full = torch.zeros(original_shape, dtype=original_dtype, device=device)
-
-        for precision_key in compressed_data['keys']:
-            data_dict = compressed_data['keys'][precision_key]
-            if 'data' in data_dict and 'indices' in data_dict:
-                indices = data_dict['indices']
-                idx_tensor = torch.tensor(indices, device=device, dtype=torch.long)
-
-                # All data stored as original precision
-                keys_full.index_copy_(2, idx_tensor, data_dict['data'])
-                values_full.index_copy_(2, idx_tensor, compressed_data['values'][precision_key]['data'])
-
-        return keys_full, values_full
-
-    def get_memory_footprint(self, compressed_data: Dict[str, Any]) -> int:
-        """
-        Calculate ACTUAL memory usage - NO ESTIMATES.
-        Every byte is accounted for explicitly.
-        """
-        total_bytes = 0
-
-        try:
-            # Count all stored tensors
-            for storage_type in ['keys', 'values']:
-                for key, data in compressed_data[storage_type].items():
-                    if isinstance(data, dict):
-                        # Data tensors
-                        if 'data' in data and isinstance(data['data'], torch.Tensor):
-                            total_bytes += data['data'].nelement() * data['data'].element_size()
-
-                        # Scale/zero tensors
-                        if 'scale' in data and isinstance(data['scale'], torch.Tensor):
-                            total_bytes += data['scale'].nelement() * data['scale'].element_size()
-                        if 'zero' in data and isinstance(data['zero'], torch.Tensor):
-                            total_bytes += data['zero'].nelement() * data['zero'].element_size()
-
-                        # Levels tensor for bit-packed data
-                        if 'levels' in data and isinstance(data['levels'], torch.Tensor):
-                            total_bytes += data['levels'].nelement() * data['levels'].element_size()
-
-                        # Metadata overhead (measured, not estimated)
-                        if 'meta' in data and isinstance(data['meta'], dict):
-                            total_bytes += self.constants.INT2_METADATA_BYTES
-
-                        # Indices (count only once under keys to avoid double counting)
-                        if storage_type == 'keys' and 'indices' in data and data['indices']:
-                            total_bytes += len(data['indices']) * self.constants.INDEX_SIZE_BYTES
-
-            # Metadata overhead
-            total_bytes += self.constants.METADATA_OVERHEAD_BYTES
-
-            logger.debug(f"Measured memory footprint: {total_bytes} bytes ({total_bytes/1024/1024:.2f} MB)")
-            return total_bytes
-
-        except Exception as e:
-            logger.error(f"Error calculating memory footprint: {e}")
-            raise
-
-    def update_quality_feedback(self, layer_idx: int, quality_metric: float):
-        """Update quality feedback for progressive compression."""
-        self.quality_history.append(quality_metric)
-
-        # Keep only recent history
-        if len(self.quality_history) > self.constants.QUALITY_HISTORY_MAX_SIZE:
-            self.quality_history = self.quality_history[-self.constants.QUALITY_HISTORY_MAX_SIZE:]
-
-
-class QuantizedKVCache:
-    """Enhanced quantized KV cache with working multi-stage SPG support."""
-
-    def __init__(self, config: CompressionConfig):
-        self.config = config
-        self.compressed_data = {}
-        self.dtypes = {}
-
-        # Initialize enhanced SPG with RocketKV features
-        if config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG]:
-            spg_config = replace(config.enhanced_spg_config,
-                                 enable_two_stage=False,
-                                 enable_adaptive=(config.compression_type == CompressionType.ADAPTIVE_SPG))
-            self.spg = EnhancedSlidingPrecisionGradient(spg_config)
-        elif config.compression_type in [CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
-            enhanced_config = config.enhanced_spg_config
-            if config.compression_type == CompressionType.PROGRESSIVE_SPG:
-                enhanced_config.enable_progressive = True
-            self.spg = EnhancedSlidingPrecisionGradient(enhanced_config)
-        else:
-            self.spg = None
-
-        self.current_position = 0
-        self.quality_history = []
-        self.n_layers = None
-
-    def compress_and_store(self, layer_idx: int, keys: torch.Tensor, values: torch.Tensor):
-        """Compress and store KV pairs with enhanced SPG support."""
-        key_dtype = keys.dtype
-        value_dtype = values.dtype
-
-        if self.config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG,
-                                            CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
-            if self.spg.layer_decay_rates is None:
-                if self.n_layers is None:
-                    raise ValueError("Model layer count not set - call detect_model_layers first")
-                self.spg.initialize_layer_decay_rates(self.n_layers)
-
-            if self.config.compression_type in [CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
-                compressed_data = self.spg.compress_with_enhanced_gradient(
-                    keys, values, layer_idx, self.current_position
-                )
-            else:
-                compressed_data = self.spg._fallback_to_original_spg(
-                    keys, values, layer_idx, self.current_position
-                )
-
-            self.compressed_data[layer_idx] = compressed_data
-            self.dtypes[layer_idx] = {'keys': key_dtype, 'values': value_dtype}
-        else:
-            # No compression - store original tensors
-            self.compressed_data[layer_idx] = {
-                'keys': {'original': {'data': keys.clone(), 'indices': list(range(keys.shape[2]))}},
-                'values': {'original': {'data': values.clone(), 'indices': list(range(values.shape[2]))}},
-                'metadata': {
-                    'compression_type': 'none',
-                    'original_shape': keys.shape,
-                    'original_dtype': keys.dtype
-                }
-            }
-            self.dtypes[layer_idx] = {'keys': key_dtype, 'values': value_dtype}
-
-    def get_decompressed(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Get decompressed KV pairs with enhanced SPG support."""
-        if self.config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG,
-                                            CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
-            if layer_idx in self.compressed_data:
-                return self.spg.decompress(self.compressed_data[layer_idx])
-            return None, None
-        else:
-            # No compression - return original tensors
-            if layer_idx in self.compressed_data:
-                data = self.compressed_data[layer_idx]
-                return data['keys']['original']['data'], data['values']['original']['data']
-            return None, None
-
-    def get_memory_footprint(self) -> int:
-        """Calculate actual memory usage with enhanced SPG support."""
-        total_bytes = 0
-        constants = ResearchConstants()
-
-        if self.config.compression_type in [CompressionType.SPG, CompressionType.ADAPTIVE_SPG,
-                                            CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
-            for layer_idx in self.compressed_data:
-                total_bytes += self.spg.get_memory_footprint(self.compressed_data[layer_idx])
-        else:
-            # No compression - calculate uncompressed memory
-            for layer_idx in self.compressed_data:
-                data = self.compressed_data[layer_idx]
-                keys_data = data['keys']['original']['data']
-                values_data = data['values']['original']['data']
-                total_bytes += keys_data.nelement() * keys_data.element_size()
-                total_bytes += values_data.nelement() * values_data.element_size()
-                total_bytes += constants.METADATA_OVERHEAD_BYTES
-
-        return total_bytes
-
-    def update_position(self, new_position: int):
-        """Update current generation position."""
-        self.current_position = new_position
-
-    def update_quality_feedback(self, layer_idx: int, quality_metric: float):
-        """Provide quality feedback for adaptive methods."""
-        if self.config.compression_type == CompressionType.ADAPTIVE_SPG and hasattr(self.spg, 'update_decay_rate'):
-            target_quality = self.config.enhanced_spg_config.target_perplexity_delta
-            self.spg.update_decay_rate(layer_idx, quality_metric, target_quality)
-            self.quality_history.append((layer_idx, quality_metric))
-        elif self.config.compression_type in [CompressionType.ENHANCED_SPG, CompressionType.PROGRESSIVE_SPG]:
-            self.spg.update_quality_feedback(layer_idx, quality_metric)
-
-
-def detect_model_layers(model) -> int:
-    """Detect the number of transformer layers with comprehensive validation."""
-    config_attrs = [
-        'num_hidden_layers',
-        'n_layer',
-        'num_layers',
-        'n_layers',
-        'decoder_layers',
-        'n_head_layers',
-    ]
-
-    for attr in config_attrs:
-        if hasattr(model.config, attr):
-            n_layers = getattr(model.config, attr)
-            if isinstance(n_layers, int) and n_layers > 0:
-                logger.info(f"Detected {n_layers} layers from config.{attr}")
-                return n_layers
-
1018
- layer_patterns = [
1019
- 'layer', 'layers', 'h', 'blocks', 'decoder.layers', 'transformer_blocks', 'decoderLayer',
1020
- ]
1021
-
1022
- for module_name, module in model.named_modules():
1023
- for pattern in layer_patterns:
1024
- if pattern in module_name.lower():
1025
- if hasattr(module, '__len__'):
1026
- n_layers = len(module)
1027
- if n_layers > 0:
1028
- logger.info(f"Detected {n_layers} layers by counting {module_name}")
1029
- return n_layers
1030
-
-    decoder_layer_types = [
-        'TransformerBlock', 'DecoderLayer', 'EncoderLayer', 'Block', 'Layer',
-        'GPT2Block', 'LlamaDecoderLayer', 'MistralDecoderLayer', 'OPTDecoderLayer',
-    ]
-
-    layers = []
-    for module in model.modules():
-        module_type = type(module).__name__
-        if any(layer_type in module_type for layer_type in decoder_layer_types):
-            layers.append(module)
-
-    if layers:
-        n_layers = len(set(layers))
-        if n_layers > 0:
-            logger.info(f"Detected {n_layers} layers by module type matching")
-            return n_layers
-
-    # Fail fast if cannot detect layers
-    raise ValueError(
-        f"Could not automatically detect the number of layers for model {type(model).__name__}. "
-        "Please check the model architecture and update the detection logic."
-    )