| """Cross-validation module tests. |
| |
| Run from project root: |
| python3 scripts/test_cv.py |
| """ |
| from __future__ import annotations |
|
|
| import sys |
| import traceback |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| sys.path.insert(0, str(ROOT / "src")) |
|
|
| import numpy as np |
| import torch |
| from dce_analyzer.simulate import generate_simulated_dce |
| from dce_analyzer.config import FullModelSpec, VariableSpec |
| from dce_analyzer.cross_validation import cross_validate, CrossValidationResult |
|
|
| |
| |
| |
|
|
| _results: list[tuple[str, bool, str]] = [] |
|
|
|
|
| def _run(name: str, fn): |
| """Run *fn* and record PASS / FAIL.""" |
| try: |
| fn() |
| _results.append((name, True, "")) |
| print(f" PASS {name}") |
| except Exception as exc: |
| msg = f"{exc.__class__.__name__}: {exc}" |
| _results.append((name, False, msg)) |
| print(f" FAIL {name}") |
| traceback.print_exc() |
| print() |
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Shared module-level fixtures, built once and reused by the tests below.
# ---------------------------------------------------------------------------

# Simulated DCE panel: 30 respondents x 4 tasks x 3 alternatives, fixed seed
# so every run exercises identical data.
sim = generate_simulated_dce(n_individuals=30, n_tasks=4, n_alts=3, seed=42)
DF = sim.data


# All three attributes as fixed (non-random) coefficients — the conditional
# logit specification.
VARS_FIXED = [
    VariableSpec(name="price", column="price", distribution="fixed"),
    VariableSpec(name="time", column="time", distribution="fixed"),
    VariableSpec(name="comfort", column="comfort", distribution="fixed"),
]


# Baseline conditional-logit model spec shared by tests 3, 5 and 7.
SPEC_CL = FullModelSpec(
    id_col="respondent_id",
    task_col="task_id",
    alt_col="alternative",
    choice_col="choice",
    variables=VARS_FIXED,
    model_type="conditional",
    maxiter=100,
)


# All estimation is forced onto CPU so results are machine-independent.
CPU = torch.device("cpu")


# Populated by test 3 and consumed by test 5 (hit-rate bounds check).
_cl_result: CrossValidationResult | None = None
|
|
| |
| |
| |
|
|
def test_kfold_split_preserves_all_individuals():
    """1. K-fold split preserves all individuals."""
    unique_ids = DF["respondent_id"].unique()
    shuffled = unique_ids.copy()
    # Same seed the CV module uses, so the split mirrors the real one.
    np.random.default_rng(42).shuffle(shuffled)
    folds = np.array_split(shuffled, 5)

    # Every respondent must land in some fold: the union of all folds has to
    # recover the full ID set exactly.
    union = {rid for fold in folds for rid in fold.tolist()}
    assert union == set(unique_ids), (
        f"Union of folds ({len(union)}) != all IDs ({len(unique_ids)})"
    )


_run("1. K-fold split preserves all individuals", test_kfold_split_preserves_all_individuals)
|
|
|
|
def test_no_individual_in_both_train_and_test():
    """2. No individual appears in both train and test for any fold."""
    unique_ids = DF["respondent_id"].unique()
    rng = np.random.default_rng(42)
    ids_copy = unique_ids.copy()
    rng.shuffle(ids_copy)
    k = 5
    folds = np.array_split(ids_copy, k)

    for fold_idx in range(k):
        test_ids = set(folds[fold_idx].tolist())
        # BUG FIX: the previous version defined train_ids as
        # set(unique_ids) - test_ids, which makes the overlap empty by
        # construction — the assertion could never fail. Build the training
        # set the way CV actually does (union of the remaining folds) so that
        # disjointness is a genuine property of the split being tested.
        train_ids: set = set()
        for j in range(k):
            if j != fold_idx:
                train_ids.update(folds[j].tolist())
        overlap = test_ids & train_ids
        assert len(overlap) == 0, (
            f"Fold {fold_idx}: overlap={overlap}"
        )


_run("2. No individual in both train and test", test_no_individual_in_both_train_and_test)
|
|
|
|
def test_cv_conditional_logit():
    """3. CV with Conditional Logit (3-fold)."""
    global _cl_result
    result = cross_validate(DF, SPEC_CL, k=3, seed=42, device=CPU)
    _cl_result = result  # cached for the hit-rate test (test 5)

    # Structural checks on the aggregate CV result.
    assert isinstance(result, CrossValidationResult), "Wrong return type"
    assert result.k == 3, f"Expected k=3, got k={result.k}"
    assert len(result.fold_results) == 3, (
        f"Expected 3 fold results, got {len(result.fold_results)}"
    )
    # Log-likelihoods of a discrete-choice model are negative by definition.
    assert result.mean_test_ll < 0, (
        f"Expected negative mean test LL, got {result.mean_test_ll}"
    )
    assert result.model_type == "conditional"
    assert result.total_runtime > 0


_run("3. CV with Conditional Logit (3-fold)", test_cv_conditional_logit)
|
|
|
|
def test_cv_mixed_logit():
    """4. CV with Mixed Logit (3-fold)."""
    # Price and time get random (normal) coefficients; comfort stays fixed.
    random_vars = [
        VariableSpec(name="price", column="price", distribution="normal"),
        VariableSpec(name="time", column="time", distribution="normal"),
        VariableSpec(name="comfort", column="comfort", distribution="fixed"),
    ]
    # Small draw count / iteration budget keep the test fast.
    mxl_spec = FullModelSpec(
        id_col="respondent_id",
        task_col="task_id",
        alt_col="alternative",
        choice_col="choice",
        variables=random_vars,
        model_type="mixed",
        n_draws=50,
        maxiter=50,
    )
    outcome = cross_validate(DF, mxl_spec, k=3, seed=42, device=CPU)
    assert isinstance(outcome, CrossValidationResult)
    assert outcome.k == 3
    assert len(outcome.fold_results) == 3
    assert outcome.model_type == "mixed"


_run("4. CV with Mixed Logit (3-fold)", test_cv_mixed_logit)
|
|
|
|
def test_hit_rate_bounds():
    """5. Hit rate is between 0 and 1."""
    # Relies on the CL result cached by test 3.
    assert _cl_result is not None, "Test 3 must pass first (CL result needed)"
    for fr in _cl_result.fold_results:
        in_unit_interval = 0.0 <= fr.hit_rate <= 1.0
        assert in_unit_interval, f"Fold {fr.fold}: hit_rate={fr.hit_rate}"
    mean_in_bounds = 0.0 <= _cl_result.mean_hit_rate <= 1.0
    assert mean_in_bounds, (
        f"mean_hit_rate={_cl_result.mean_hit_rate}"
    )


_run("5. Hit rate is between 0 and 1", test_hit_rate_bounds)
|
|
|
|
def test_k_greater_than_n_individuals_raises():
    """6. K > n_individuals raises an error."""
    # Fresh 10-person dataset: k=100 cannot be satisfied.
    sim_small = generate_simulated_dce(n_individuals=10, n_tasks=4, n_alts=3, seed=99)
    df_small = sim_small.data
    spec_small = FullModelSpec(
        id_col="respondent_id",
        task_col="task_id",
        alt_col="alternative",
        choice_col="choice",
        variables=VARS_FIXED,
        model_type="conditional",
        maxiter=50,
    )
    try:
        cross_validate(df_small, spec_small, k=100, seed=42, device=CPU)
    except ValueError as e:
        # The error message should name the offending k value.
        assert "k=100" in str(e), f"Unexpected error message: {e}"
    else:
        raise AssertionError("Expected ValueError when K > n_individuals")


_run("6. K > n_individuals raises error", test_k_greater_than_n_individuals_raises)
|
|
|
|
def test_progress_callback_called():
    """7. Progress callback is called K times."""
    recorded: list[tuple[int, int, str]] = []

    def on_progress(fold_idx, k, status):
        recorded.append((fold_idx, k, status))

    n_folds = 3
    cross_validate(DF, SPEC_CL, k=n_folds, seed=42, device=CPU, progress_callback=on_progress)
    assert len(recorded) == n_folds, f"Expected {n_folds} callback calls, got {len(recorded)}"
    # Calls must arrive in fold order, each reporting the full fold count.
    for expected_idx, (fold_idx, k_val, _status) in enumerate(recorded):
        assert fold_idx == expected_idx, f"Expected fold_idx={expected_idx}, got {fold_idx}"
        assert k_val == n_folds, f"Expected k={n_folds}, got {k_val}"


_run("7. Progress callback is called K times", test_progress_callback_called)
|
|
|
|
| |
| |
| |
|
|
def _summary():
    """Print the final PASS/FAIL banner; return True iff every test passed."""
    total = len(_results)
    failures = [(name, msg) for name, ok, msg in _results if not ok]
    failed = len(failures)
    passed = total - failed
    print(f"\n{'='*60}")
    print(f" {passed} passed, {failed} failed out of {total} tests")
    print(f"{'='*60}")
    if failures:
        print("\n Failed tests:")
        for name, msg in failures:
            print(f" - {name}: {msg}")
        return False
    print(" ALL TESTS PASSED")
    return True
|
|
|
|
if __name__ == "__main__":
    # Exit code 0 only when every recorded test passed.
    sys.exit(0 if _summary() else 1)
|
|