kfoughali commited on
Commit
89ac4ea
·
verified ·
1 Parent(s): 46afae7

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +25 -7
config.py CHANGED
@@ -429,23 +429,25 @@ class CompressionConfig:
429
  timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
430
 
431
  def __post_init__(self):
432
- """Comprehensive validation - fail fast on any invalid parameter."""
433
  constants = ResearchConstants()
434
 
435
- # Set model name from key
436
  if self.model_key not in SUPPORTED_MODELS:
437
  raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
438
  self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
 
439
 
440
- # Validate benchmark type
441
  if self.benchmark_type not in BENCHMARK_CONFIGS:
442
  raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
 
443
 
444
- # Validate core parameters
445
  if not isinstance(self.seed, int) or self.seed < 0:
446
  raise ValueError(f"seed must be non-negative integer, got {self.seed}")
447
 
448
- # Validate evaluation parameters
449
  if not constants.MIN_EVAL_SAMPLES <= self.eval_samples <= constants.MAX_EVAL_SAMPLES:
450
  logger.warning(f"eval_samples {self.eval_samples} outside recommended range [{constants.MIN_EVAL_SAMPLES}, {constants.MAX_EVAL_SAMPLES}]")
451
 
@@ -458,17 +460,33 @@ class CompressionConfig:
458
  if not 1 <= self.n_seeds <= 10:
459
  logger.warning(f"n_seeds {self.n_seeds} outside recommended range [1, 10]")
460
 
461
- # Validate statistical parameters
462
  if not 0.5 <= self.confidence_level < 1.0:
463
  raise ValueError(f"confidence_level must be in [0.5, 1.0), got {self.confidence_level}")
464
 
465
  if not 100 <= self.n_bootstrap <= 10000:
466
  logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  # Pass Flash Attention setting to EnhancedSPGConfig
469
  self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
470
 
471
- logger.info("RocketKV-enhanced SPG config validated successfully")
 
 
 
472
 
473
  def to_json(self) -> str:
474
  """Export config for reproducibility."""
 
429
  timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
430
 
431
  def __post_init__(self):
432
+ """Comprehensive validation - FAIL FAST on any invalid parameter, NO SILENT DEFAULTS."""
433
  constants = ResearchConstants()
434
 
435
+ # Set model name from key - FAIL FAST if invalid
436
  if self.model_key not in SUPPORTED_MODELS:
437
  raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
438
  self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
439
+ logger.info(f"Model selected: {self.model_name} (key: {self.model_key})")
440
 
441
+ # Validate benchmark type - FAIL FAST if invalid
442
  if self.benchmark_type not in BENCHMARK_CONFIGS:
443
  raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
444
+ logger.info(f"Benchmark selected: {self.benchmark_type}")
445
 
446
+ # Validate core parameters - NO MAGIC NUMBERS
447
  if not isinstance(self.seed, int) or self.seed < 0:
448
  raise ValueError(f"seed must be non-negative integer, got {self.seed}")
449
 
450
+ # Validate evaluation parameters with explicit bounds
451
  if not constants.MIN_EVAL_SAMPLES <= self.eval_samples <= constants.MAX_EVAL_SAMPLES:
452
  logger.warning(f"eval_samples {self.eval_samples} outside recommended range [{constants.MIN_EVAL_SAMPLES}, {constants.MAX_EVAL_SAMPLES}]")
453
 
 
460
  if not 1 <= self.n_seeds <= 10:
461
  logger.warning(f"n_seeds {self.n_seeds} outside recommended range [1, 10]")
462
 
463
+ # Validate statistical parameters - EXPLICIT BOUNDS
464
  if not 0.5 <= self.confidence_level < 1.0:
465
  raise ValueError(f"confidence_level must be in [0.5, 1.0), got {self.confidence_level}")
466
 
467
  if not 100 <= self.n_bootstrap <= 10000:
468
  logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
469
 
470
+ # Validate benchmark-specific parameters
471
+ if self.benchmark_type == "longbench" and not self.benchmark_subset:
472
+ logger.warning("LongBench selected but no subset specified")
473
+
474
+ if self.benchmark_type == "niah" and not self.niah_needle:
475
+ raise ValueError("NIAH benchmark requires niah_needle to be set")
476
+
477
+ if self.benchmark_type == "ruler" and self.ruler_max_seq_length <= 0:
478
+ raise ValueError(f"ruler_max_seq_length must be positive, got {self.ruler_max_seq_length}")
479
+
480
+ if self.benchmark_type == "scbench" and self.scbench_num_turns <= 0:
481
+ raise ValueError(f"scbench_num_turns must be positive, got {self.scbench_num_turns}")
482
+
483
  # Pass Flash Attention setting to EnhancedSPGConfig
484
  self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
485
 
486
+ logger.info("Configuration validated successfully - STRICT COMPLIANCE")
487
+ logger.info(f"Target compression: {self.enhanced_spg_config.target_compression_ratio}x")
488
+ logger.info(f"Fail on CPU fallback: {self.fail_on_cpu_fallback}")
489
+ logger.info(f"Proving enabled: {self.proving.enabled}")
490
 
491
  def to_json(self) -> str:
492
  """Export config for reproducibility."""