Update config.py
Browse files
config.py
CHANGED
|
@@ -429,23 +429,25 @@ class CompressionConfig:
|
|
| 429 |
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 430 |
|
| 431 |
def __post_init__(self):
|
| 432 |
-
"""Comprehensive validation - [remainder of this removed line was lost in extraction]
|
| 433 |
constants = ResearchConstants()
|
| 434 |
|
| 435 |
-
# Set model name from key
|
| 436 |
if self.model_key not in SUPPORTED_MODELS:
|
| 437 |
raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
|
| 438 |
self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
|
|
|
|
| 439 |
|
| 440 |
-
# Validate benchmark type
|
| 441 |
if self.benchmark_type not in BENCHMARK_CONFIGS:
|
| 442 |
raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
|
|
|
|
| 443 |
|
| 444 |
-
# Validate core parameters
|
| 445 |
if not isinstance(self.seed, int) or self.seed < 0:
|
| 446 |
raise ValueError(f"seed must be non-negative integer, got {self.seed}")
|
| 447 |
|
| 448 |
-
# Validate evaluation parameters
|
| 449 |
if not constants.MIN_EVAL_SAMPLES <= self.eval_samples <= constants.MAX_EVAL_SAMPLES:
|
| 450 |
logger.warning(f"eval_samples {self.eval_samples} outside recommended range [{constants.MIN_EVAL_SAMPLES}, {constants.MAX_EVAL_SAMPLES}]")
|
| 451 |
|
|
@@ -458,17 +460,33 @@ class CompressionConfig:
|
|
| 458 |
if not 1 <= self.n_seeds <= 10:
|
| 459 |
logger.warning(f"n_seeds {self.n_seeds} outside recommended range [1, 10]")
|
| 460 |
|
| 461 |
-
# Validate statistical parameters
|
| 462 |
if not 0.5 <= self.confidence_level < 1.0:
|
| 463 |
raise ValueError(f"confidence_level must be in [0.5, 1.0), got {self.confidence_level}")
|
| 464 |
|
| 465 |
if not 100 <= self.n_bootstrap <= 10000:
|
| 466 |
logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
|
| 467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
# Pass Flash Attention setting to EnhancedSPGConfig
|
| 469 |
self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
|
| 470 |
|
| 471 |
-
logger.info(" [remainder of this removed line was lost in extraction]
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
def to_json(self) -> str:
|
| 474 |
"""Export config for reproducibility."""
|
|
|
|
| 429 |
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 430 |
|
| 431 |
def __post_init__(self):
|
| 432 |
+
"""Comprehensive validation - FAIL FAST on any invalid parameter, NO SILENT DEFAULTS."""
|
| 433 |
constants = ResearchConstants()
|
| 434 |
|
| 435 |
+
# Set model name from key - FAIL FAST if invalid
|
| 436 |
if self.model_key not in SUPPORTED_MODELS:
|
| 437 |
raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
|
| 438 |
self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
|
| 439 |
+
logger.info(f"Model selected: {self.model_name} (key: {self.model_key})")
|
| 440 |
|
| 441 |
+
# Validate benchmark type - FAIL FAST if invalid
|
| 442 |
if self.benchmark_type not in BENCHMARK_CONFIGS:
|
| 443 |
raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
|
| 444 |
+
logger.info(f"Benchmark selected: {self.benchmark_type}")
|
| 445 |
|
| 446 |
+
# Validate core parameters - NO MAGIC NUMBERS
|
| 447 |
if not isinstance(self.seed, int) or self.seed < 0:
|
| 448 |
raise ValueError(f"seed must be non-negative integer, got {self.seed}")
|
| 449 |
|
| 450 |
+
# Validate evaluation parameters with explicit bounds
|
| 451 |
if not constants.MIN_EVAL_SAMPLES <= self.eval_samples <= constants.MAX_EVAL_SAMPLES:
|
| 452 |
logger.warning(f"eval_samples {self.eval_samples} outside recommended range [{constants.MIN_EVAL_SAMPLES}, {constants.MAX_EVAL_SAMPLES}]")
|
| 453 |
|
|
|
|
| 460 |
if not 1 <= self.n_seeds <= 10:
|
| 461 |
logger.warning(f"n_seeds {self.n_seeds} outside recommended range [1, 10]")
|
| 462 |
|
| 463 |
+
# Validate statistical parameters - EXPLICIT BOUNDS
|
| 464 |
if not 0.5 <= self.confidence_level < 1.0:
|
| 465 |
raise ValueError(f"confidence_level must be in [0.5, 1.0), got {self.confidence_level}")
|
| 466 |
|
| 467 |
if not 100 <= self.n_bootstrap <= 10000:
|
| 468 |
logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
|
| 469 |
|
| 470 |
+
# Validate benchmark-specific parameters
|
| 471 |
+
if self.benchmark_type == "longbench" and not self.benchmark_subset:
|
| 472 |
+
logger.warning("LongBench selected but no subset specified")
|
| 473 |
+
|
| 474 |
+
if self.benchmark_type == "niah" and not self.niah_needle:
|
| 475 |
+
raise ValueError("NIAH benchmark requires niah_needle to be set")
|
| 476 |
+
|
| 477 |
+
if self.benchmark_type == "ruler" and self.ruler_max_seq_length <= 0:
|
| 478 |
+
raise ValueError(f"ruler_max_seq_length must be positive, got {self.ruler_max_seq_length}")
|
| 479 |
+
|
| 480 |
+
if self.benchmark_type == "scbench" and self.scbench_num_turns <= 0:
|
| 481 |
+
raise ValueError(f"scbench_num_turns must be positive, got {self.scbench_num_turns}")
|
| 482 |
+
|
| 483 |
# Pass Flash Attention setting to EnhancedSPGConfig
|
| 484 |
self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
|
| 485 |
|
| 486 |
+
logger.info("Configuration validated successfully - STRICT COMPLIANCE")
|
| 487 |
+
logger.info(f"Target compression: {self.enhanced_spg_config.target_compression_ratio}x")
|
| 488 |
+
logger.info(f"Fail on CPU fallback: {self.fail_on_cpu_fallback}")
|
| 489 |
+
logger.info(f"Proving enabled: {self.proving.enabled}")
|
| 490 |
|
| 491 |
def to_json(self) -> str:
|
| 492 |
"""Export config for reproducibility."""
|