Update config.py
Browse files
config.py
CHANGED
|
@@ -7,7 +7,7 @@ import json
|
|
| 7 |
import hashlib
|
| 8 |
from dataclasses import dataclass, field, asdict
|
| 9 |
from enum import Enum
|
| 10 |
-
from typing import List, Optional, NamedTuple
|
| 11 |
from datetime import datetime
|
| 12 |
import torch
|
| 13 |
import transformers
|
|
@@ -17,6 +17,69 @@ import logging
|
|
| 17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
class CompressionType(Enum):
|
| 21 |
"""RocketKV-enhanced SPG methods with explicit validation."""
|
| 22 |
NONE = "none"
|
|
@@ -184,6 +247,9 @@ class EnhancedSPGConfig:
|
|
| 184 |
stage_compression_min: float = 2.0 # Minimum stage compression ratio
|
| 185 |
stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
|
| 186 |
|
|
|
|
|
|
|
|
|
|
| 187 |
def __post_init__(self):
|
| 188 |
"""Validate all parameters - fail fast on invalid config."""
|
| 189 |
constants = ResearchConstants()
|
|
@@ -304,6 +370,10 @@ class CompressionConfig:
|
|
| 304 |
compression_type: CompressionType = CompressionType.ENHANCED_SPG
|
| 305 |
seed: int = 42
|
| 306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
# Enhanced SPG configuration
|
| 308 |
enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
|
| 309 |
|
|
@@ -327,10 +397,25 @@ class CompressionConfig:
|
|
| 327 |
dataset_config: str = "wikitext-2-raw-v1"
|
| 328 |
dataset_split: str = "test"
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
# Memory and system settings
|
| 331 |
clear_cache_between_runs: bool = True
|
| 332 |
use_memory_snapshot: bool = True
|
| 333 |
fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
|
|
|
|
| 334 |
|
| 335 |
# Output settings
|
| 336 |
generate_latex: bool = True
|
|
@@ -347,6 +432,15 @@ class CompressionConfig:
|
|
| 347 |
"""Comprehensive validation - fail fast on any invalid parameter."""
|
| 348 |
constants = ResearchConstants()
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
# Validate core parameters
|
| 351 |
if not isinstance(self.seed, int) or self.seed < 0:
|
| 352 |
raise ValueError(f"seed must be non-negative integer, got {self.seed}")
|
|
@@ -371,6 +465,9 @@ class CompressionConfig:
|
|
| 371 |
if not 100 <= self.n_bootstrap <= 10000:
|
| 372 |
logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
|
| 373 |
|
|
|
|
|
|
|
|
|
|
| 374 |
logger.info("RocketKV-enhanced SPG config validated successfully")
|
| 375 |
|
| 376 |
def to_json(self) -> str:
|
|
|
|
| 7 |
import hashlib
|
| 8 |
from dataclasses import dataclass, field, asdict
|
| 9 |
from enum import Enum
|
| 10 |
+
from typing import List, Optional, NamedTuple, Dict, Any
|
| 11 |
from datetime import datetime
|
| 12 |
import torch
|
| 13 |
import transformers
|
|
|
|
| 17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
+
# Registry of supported models, keyed by the short identifier passed as
# CompressionConfig.model_key. CompressionConfig.__post_init__ rejects any key
# that is missing from this table and copies the entry's "name" into
# model_name, so model identifiers live here instead of being hard-coded at
# call sites.
#
# Per-model fields:
#   name          - value stored in CompressionConfig.model_name
#                   (looks like a Hugging Face repo id - TODO confirm)
#   requires_auth - NOTE(review): presumably marks gated checkpoints that need
#                   an access token; confirm against the model loader
#   max_context   - NOTE(review): presumably the context window in tokens
#   default_dtype - dtype given by name; NOTE(review): confirm how the loader
#                   maps this string to a torch dtype
SUPPORTED_MODELS: Dict[str, Dict[str, Any]] = {
    "gpt2": {
        "name": "gpt2",
        "requires_auth": False,
        "max_context": 1024,
        "default_dtype": "float16",
    },
    "llama2-7b": {
        "name": "meta-llama/Llama-2-7b-hf",
        "requires_auth": True,
        "max_context": 4096,
        "default_dtype": "float16",
    },
    "mistral-7b": {
        "name": "mistralai/Mistral-7B-v0.1",
        "requires_auth": False,
        "max_context": 8192,
        "default_dtype": "float16",
    },
    "opt-1.3b": {
        "name": "facebook/opt-1.3b",
        "requires_auth": False,
        "max_context": 2048,
        "default_dtype": "float16",
    },
}
|
| 47 |
+
|
| 48 |
+
# Registry of benchmark presets, keyed by the identifier passed as
# CompressionConfig.benchmark_type. CompressionConfig.__post_init__ rejects any
# benchmark_type that is not a key of this table, and the "niah" entry's
# "needle" is the default for CompressionConfig.niah_needle, so benchmark
# parameters live here instead of being hard-coded at call sites.
#
# Every entry carries a "type" tag and a "default_samples" count; the other
# fields are benchmark-specific.
# NOTE(review): the consumers of these defaults are not visible in this module
# excerpt - confirm units (tokens vs. characters) against the benchmark runners.
BENCHMARK_CONFIGS: Dict[str, Dict[str, Any]] = {
    "perplexity": {
        "type": "perplexity",
        "default_samples": 50,
        "default_prefill": 512,
        "default_generation": 64,
    },
    "niah": {
        "type": "needle_in_haystack",
        "depths": [10, 25, 50, 75, 90],  # Percentage depths
        "needle": "The secret password is BANANA",
        "default_samples": 10,
        "default_context": 4096,
    },
    "ruler": {
        "type": "ruler",
        "max_seq_lengths": [1024, 2048, 4096, 8192],
        "default_samples": 10,
        "default_n_facts": 10,
    },
    "scbench": {
        "type": "shared_context",
        "num_turns": [5, 10, 20],
        "default_samples": 10,
        "default_context": 2048,
    },
    "longbench": {
        "type": "longbench",
        "subsets": ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa"],
        "default_samples": 20,
        "max_context": 8192,
    },
}
|
| 82 |
+
|
| 83 |
class CompressionType(Enum):
|
| 84 |
"""RocketKV-enhanced SPG methods with explicit validation."""
|
| 85 |
NONE = "none"
|
|
|
|
| 247 |
stage_compression_min: float = 2.0 # Minimum stage compression ratio
|
| 248 |
stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
|
| 249 |
|
| 250 |
+
# Flash Attention support
|
| 251 |
+
use_flash_attention: bool = False # Try to use Flash Attention if available
|
| 252 |
+
|
| 253 |
def __post_init__(self):
|
| 254 |
"""Validate all parameters - fail fast on invalid config."""
|
| 255 |
constants = ResearchConstants()
|
|
|
|
| 370 |
compression_type: CompressionType = CompressionType.ENHANCED_SPG
|
| 371 |
seed: int = 42
|
| 372 |
|
| 373 |
+
# Model selection
|
| 374 |
+
model_key: str = "gpt2" # Key into SUPPORTED_MODELS
|
| 375 |
+
model_name: str = field(init=False) # Will be set in __post_init__
|
| 376 |
+
|
| 377 |
# Enhanced SPG configuration
|
| 378 |
enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
|
| 379 |
|
|
|
|
| 397 |
dataset_config: str = "wikitext-2-raw-v1"
|
| 398 |
dataset_split: str = "test"
|
| 399 |
|
| 400 |
+
# Benchmark configuration
|
| 401 |
+
benchmark_type: str = "perplexity" # perplexity, niah, ruler, scbench, longbench
|
| 402 |
+
benchmark_subset: Optional[str] = None # For longbench subsets
|
| 403 |
+
|
| 404 |
+
# NIAH-specific parameters
|
| 405 |
+
niah_needle: str = field(default_factory=lambda: BENCHMARK_CONFIGS["niah"]["needle"])
|
| 406 |
+
niah_depth_percent: float = 50.0
|
| 407 |
+
|
| 408 |
+
# RULER-specific parameters
|
| 409 |
+
ruler_max_seq_length: int = 4096
|
| 410 |
+
|
| 411 |
+
# SCBench-specific parameters
|
| 412 |
+
scbench_num_turns: int = 10
|
| 413 |
+
|
| 414 |
# Memory and system settings
|
| 415 |
clear_cache_between_runs: bool = True
|
| 416 |
use_memory_snapshot: bool = True
|
| 417 |
fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
|
| 418 |
+
use_flash_attention: bool = False # Try to use Flash Attention if available
|
| 419 |
|
| 420 |
# Output settings
|
| 421 |
generate_latex: bool = True
|
|
|
|
| 432 |
"""Comprehensive validation - fail fast on any invalid parameter."""
|
| 433 |
constants = ResearchConstants()
|
| 434 |
|
| 435 |
+
# Set model name from key
|
| 436 |
+
if self.model_key not in SUPPORTED_MODELS:
|
| 437 |
+
raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
|
| 438 |
+
self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
|
| 439 |
+
|
| 440 |
+
# Validate benchmark type
|
| 441 |
+
if self.benchmark_type not in BENCHMARK_CONFIGS:
|
| 442 |
+
raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
|
| 443 |
+
|
| 444 |
# Validate core parameters
|
| 445 |
if not isinstance(self.seed, int) or self.seed < 0:
|
| 446 |
raise ValueError(f"seed must be non-negative integer, got {self.seed}")
|
|
|
|
| 465 |
if not 100 <= self.n_bootstrap <= 10000:
|
| 466 |
logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
|
| 467 |
|
| 468 |
+
# Pass Flash Attention setting to EnhancedSPGConfig
|
| 469 |
+
self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
|
| 470 |
+
|
| 471 |
logger.info("RocketKV-enhanced SPG config validated successfully")
|
| 472 |
|
| 473 |
def to_json(self) -> str:
|