# Researcher / tests / test_scoring_models.py
# Provenance (scraped file-browser header): commit "Add HF Spaces support,
# preference seeding, archive search, tests" (430d0f8) by amarck.
"""Tests for configurable model selection: config, scoring, pipeline integration."""
import json
import os
import sqlite3
import sys
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _isolate_config(tmp_path, monkeypatch):
    """Point config/DB env vars at a temp directory so every test is hermetic."""
    env_overrides = {
        "CONFIG_PATH": str(tmp_path / "config.yaml"),
        "DB_PATH": str(tmp_path / "researcher.db"),
        # Some tests exercise code paths that refuse to run without a key.
        "ANTHROPIC_API_KEY": "sk-ant-test-key",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)
@pytest.fixture()
def fresh_config(tmp_path):
    """Return a helper that writes a config.yaml and reloads the config module.

    The returned ``_write`` callable dumps ``data`` to a temp config.yaml,
    points ``src.config`` at that file, and re-derives the module-level
    scoring constants exactly the way the module does at import time.
    """
    config_path = tmp_path / "config.yaml"

    def _write(data: dict) -> None:
        import yaml
        config_path.write_text(yaml.dump(data, default_flow_style=False))
        # Force config module to reload from this file
        import src.config as cfg
        cfg.CONFIG_PATH = config_path
        cfg._cfg = cfg._load_yaml()
        sc = cfg._cfg.get("scoring", {})
        # NOTE(review): the assignments below mirror src.config's own startup
        # logic (including the legacy `claude_model` fallback) — keep them in
        # sync with that module if its defaults change.
        cfg.SCORING_MODEL = sc.get("model", cfg._cfg.get("claude_model", "claude-haiku-4-5-20251001"))
        cfg.RESCORE_MODEL = sc.get("rescore_model", "claude-sonnet-4-5-20250929")
        cfg.RESCORE_TOP_N = sc.get("rescore_top_n", 15)
        cfg.BATCH_SIZE = sc.get("batch_size", cfg._cfg.get("batch_size", 20))
        # update() in place so modules holding a reference to SCORING_CONFIGS
        # see the refreshed entries.
        cfg.SCORING_CONFIGS.update(cfg._build_scoring_configs())

    return _write
@pytest.fixture()
def test_db(tmp_path):
    """Create an empty researcher.db in the temp dir and hand back its path."""
    import src.config as cfg
    from src.db import init_db

    path = tmp_path / "researcher.db"
    cfg.DB_PATH = path  # point the app at the temp database before init
    init_db()
    return path
def _insert_test_papers(db_path, run_id, domain, papers):
"""Insert test papers directly into the DB."""
conn = sqlite3.connect(str(db_path))
for p in papers:
conn.execute(
"""INSERT INTO papers
(run_id, domain, arxiv_id, entry_id, title, authors, abstract,
published, categories, pdf_url, arxiv_url, comment, source,
github_repo, github_stars, hf_upvotes, hf_models, hf_datasets, hf_spaces,
score_axis_1, score_axis_2, score_axis_3, composite, summary, reasoning, code_url)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
(
run_id, domain,
p.get("arxiv_id", ""),
p.get("entry_id", ""),
p.get("title", ""),
json.dumps(p.get("authors", [])),
p.get("abstract", ""),
p.get("published", ""),
json.dumps(p.get("categories", [])),
p.get("pdf_url", ""),
p.get("arxiv_url", ""),
p.get("comment", ""),
p.get("source", ""),
p.get("github_repo", ""),
p.get("github_stars"),
p.get("hf_upvotes", 0),
json.dumps(p.get("hf_models", [])),
json.dumps(p.get("hf_datasets", [])),
json.dumps(p.get("hf_spaces", [])),
p.get("score_axis_1"),
p.get("score_axis_2"),
p.get("score_axis_3"),
p.get("composite"),
p.get("summary", ""),
p.get("reasoning", ""),
p.get("code_url"),
),
)
conn.commit()
conn.close()
def _create_test_run(db_path, domain):
"""Create a run row and return its id."""
conn = sqlite3.connect(str(db_path))
cur = conn.execute(
"INSERT INTO runs (domain, started_at, date_start, date_end, status) "
"VALUES (?, '2026-01-01T00:00:00', '2026-01-01', '2026-01-07', 'completed')",
(domain,),
)
run_id = cur.lastrowid
conn.commit()
conn.close()
return run_id
# ---------------------------------------------------------------------------
# 1. Config defaults
# ---------------------------------------------------------------------------
class TestConfigDefaults:
    """Verify default model constants when no config.yaml exists."""

    def test_default_scoring_model(self):
        import src.config as config
        # No config file on disk -> bulk scoring defaults to the haiku family.
        assert "haiku" in config.SCORING_MODEL

    def test_default_rescore_model(self):
        import src.config as config
        # Rescoring defaults to the stronger sonnet family.
        assert "sonnet" in config.RESCORE_MODEL

    def test_default_rescore_top_n(self):
        import src.config as config
        assert config.RESCORE_TOP_N == 15

    def test_default_batch_size(self):
        import src.config as config
        assert config.BATCH_SIZE == 20
# ---------------------------------------------------------------------------
# 2. Config loading from YAML
# ---------------------------------------------------------------------------
class TestConfigYAML:
    """Verify config values load from config.yaml correctly."""

    def test_scoring_block_loads(self, fresh_config):
        import src.config as cfg
        scoring = {
            "model": "claude-opus-4-6",
            "rescore_model": "claude-sonnet-4-5-20250929",
            "rescore_top_n": 25,
            "batch_size": 10,
        }
        fresh_config({"scoring": scoring})
        # Every key in the scoring block maps onto a module constant.
        assert cfg.SCORING_MODEL == scoring["model"]
        assert cfg.RESCORE_MODEL == scoring["rescore_model"]
        assert cfg.RESCORE_TOP_N == scoring["rescore_top_n"]
        assert cfg.BATCH_SIZE == scoring["batch_size"]

    def test_backward_compat_claude_model_key(self, fresh_config):
        """Old `claude_model` key is used as SCORING_MODEL fallback."""
        import src.config as cfg
        fresh_config({"claude_model": "claude-sonnet-4-5-20250929"})
        assert cfg.SCORING_MODEL == "claude-sonnet-4-5-20250929"
        # The legacy key does not touch the rescore default.
        assert "sonnet" in cfg.RESCORE_MODEL

    def test_scoring_block_overrides_claude_model(self, fresh_config):
        import src.config as cfg
        # Both keys present: the scoring block wins over the legacy key.
        fresh_config({
            "claude_model": "claude-sonnet-4-5-20250929",
            "scoring": {"model": "claude-haiku-4-5-20251001"},
        })
        assert cfg.SCORING_MODEL == "claude-haiku-4-5-20251001"

    def test_rescore_disabled_when_zero(self, fresh_config):
        import src.config as cfg
        # rescore_top_n: 0 must survive the load (0 is falsy but valid).
        fresh_config({"scoring": {"rescore_top_n": 0}})
        assert cfg.RESCORE_TOP_N == 0
# ---------------------------------------------------------------------------
# 3. save_config reloads model constants
# ---------------------------------------------------------------------------
class TestSaveConfig:
    """Verify save_config() updates module-level constants in config."""

    def test_save_updates_scoring_model(self, tmp_path):
        import src.config as cfg
        cfg.CONFIG_PATH = tmp_path / "config.yaml"
        scoring = {
            "model": "claude-opus-4-6",
            "rescore_model": "claude-haiku-4-5-20251001",
            "rescore_top_n": 5,
            "batch_size": 30,
        }
        cfg.save_config({"scoring": scoring})
        # save_config must refresh the constants, not just write the YAML.
        assert cfg.SCORING_MODEL == scoring["model"]
        assert cfg.RESCORE_MODEL == scoring["rescore_model"]
        assert cfg.RESCORE_TOP_N == scoring["rescore_top_n"]
        assert cfg.BATCH_SIZE == scoring["batch_size"]

    def test_save_without_scoring_block_uses_defaults(self, tmp_path):
        import src.config as cfg
        cfg.CONFIG_PATH = tmp_path / "config.yaml"
        cfg.save_config({"domains": {}})
        # No scoring block -> constants fall back to the shipped defaults.
        assert "haiku" in cfg.SCORING_MODEL
        assert "sonnet" in cfg.RESCORE_MODEL
        assert cfg.RESCORE_TOP_N == 15
# ---------------------------------------------------------------------------
# 4. scoring.py reads live config (not stale bindings)
# ---------------------------------------------------------------------------
class TestScoringLiveConfig:
    """Verify scoring.py reads config values at call time, not import time."""

    def test_score_run_reads_live_model(self, test_db, tmp_path):
        """After config change, score_run uses the new model."""
        import src.config as cfg
        cfg.CONFIG_PATH = tmp_path / "config.yaml"
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        # Start with haiku
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        run_id = _create_test_run(test_db, "aiml")
        _insert_test_papers(test_db, run_id, "aiml", [{
            "arxiv_id": "2601.00001",
            "title": "Test Paper",
            "abstract": "Test abstract",
            "authors": ["Author A"],
            "categories": ["cs.LG"],
        }])
        # Change config to sonnet AFTER the paper exists but BEFORE scoring:
        # a stale `from config import SCORING_MODEL` binding in scoring.py
        # would still use haiku here and fail the assert below.
        cfg.SCORING_MODEL = "claude-sonnet-4-5-20250929"
        # Mock the Anthropic client to capture which model is used
        captured_model = {}

        def mock_create(**kwargs):
            # Record the requested model and return one well-formed score row
            # so score_run's response parsing succeeds.
            captured_model["model"] = kwargs["model"]
            resp = MagicMock()
            resp.content = [MagicMock(text='[{"arxiv_id":"2601.00001","code_and_weights":7,"novelty":8,"practical_applicability":6,"summary":"test","reasoning":"test","code_url":null}]')]
            return resp

        with patch("anthropic.Anthropic") as MockClient:
            mock_instance = MagicMock()
            mock_instance.messages.create = mock_create
            MockClient.return_value = mock_instance
            from src.scoring import score_run
            score_run(run_id, "aiml")
        assert captured_model["model"] == "claude-sonnet-4-5-20250929"

    def test_rescore_reads_live_config(self, test_db, tmp_path):
        """rescore_top reads RESCORE_MODEL and RESCORE_TOP_N at call time."""
        import src.config as cfg
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-opus-4-6"
        cfg.RESCORE_TOP_N = 3
        run_id = _create_test_run(test_db, "aiml")
        # Insert scored papers with strictly descending composites so the
        # "top N" selection is deterministic (papers 0..2 are the top 3).
        for i in range(5):
            _insert_test_papers(test_db, run_id, "aiml", [{
                "arxiv_id": f"2601.{i:05d}",
                "title": f"Paper {i}",
                "abstract": f"Abstract {i}",
                "authors": ["Author"],
                "categories": ["cs.LG"],
                "composite": 8.0 - i * 0.5,
                "score_axis_1": 7, "score_axis_2": 8, "score_axis_3": 6,
                "summary": "existing", "reasoning": "existing",
            }])
        captured = {}

        def mock_create(**kwargs):
            # Capture model + prompt, then answer with scores for exactly the
            # top 3 papers (matches RESCORE_TOP_N set above).
            captured["model"] = kwargs["model"]
            captured["content"] = kwargs["messages"][0]["content"]
            resp = MagicMock()
            # Return scores for top 3
            results = []
            for i in range(3):
                results.append({
                    "arxiv_id": f"2601.{i:05d}",
                    "code_and_weights": 9, "novelty": 9, "practical_applicability": 9,
                    "summary": "rescored", "reasoning": "rescored", "code_url": None,
                })
            resp.content = [MagicMock(text=json.dumps(results))]
            return resp

        with patch("anthropic.Anthropic") as MockClient:
            mock_instance = MagicMock()
            mock_instance.messages.create = mock_create
            MockClient.return_value = mock_instance
            from src.scoring import rescore_top
            count = rescore_top(run_id, "aiml")
        assert captured["model"] == "claude-opus-4-6"
        assert count == 3
# ---------------------------------------------------------------------------
# 5. rescore_top guard conditions
# ---------------------------------------------------------------------------
class TestRescoreGuards:
    """Test rescore_top early-exit conditions."""

    def test_rescore_disabled_when_n_zero(self, test_db):
        import src.config as cfg
        cfg.RESCORE_TOP_N = 0  # rescoring disabled via config
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        from src.scoring import rescore_top
        assert rescore_top(1, "aiml") == 0

    def test_rescore_disabled_when_explicit_n_zero(self, test_db):
        import src.config as cfg
        cfg.RESCORE_TOP_N = 15  # config says 15
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        from src.scoring import rescore_top
        # NOTE(review): the test name says "disabled when explicit n zero" but
        # the trailing comment says n=0 falls through to the config value (15).
        # The assertion passes either way only because test_db has no papers
        # for run 1 — confirm which behavior rescore_top is meant to have.
        assert rescore_top(1, "aiml", n=0) == 0  # n=0 falls through to config

    def test_rescore_skipped_when_same_model(self, test_db):
        import src.config as cfg
        # Identical bulk and rescore model: a second pass would add nothing.
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_TOP_N = 15
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        from src.scoring import rescore_top
        assert rescore_top(1, "aiml") == 0

    def test_rescore_skipped_when_no_api_key(self, test_db):
        import src.config as cfg
        cfg.ANTHROPIC_API_KEY = ""  # no key -> must exit before any API call
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        cfg.RESCORE_TOP_N = 15
        from src.scoring import rescore_top
        assert rescore_top(1, "aiml") == 0

    def test_rescore_skipped_when_no_papers(self, test_db):
        import src.config as cfg
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        cfg.RESCORE_TOP_N = 15
        run_id = _create_test_run(test_db, "aiml")
        # No papers inserted
        from src.scoring import rescore_top
        assert rescore_top(run_id, "aiml") == 0

    def test_rescore_explicit_n_overrides_config(self, test_db):
        """Passing n=X should use that instead of RESCORE_TOP_N."""
        import src.config as cfg
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        cfg.RESCORE_TOP_N = 15
        run_id = _create_test_run(test_db, "aiml")
        # Insert 5 scored papers (descending composite -> deterministic top-N)
        for i in range(5):
            _insert_test_papers(test_db, run_id, "aiml", [{
                "arxiv_id": f"2601.{i:05d}",
                "title": f"Paper {i}",
                "abstract": f"Abstract {i}",
                "authors": ["Author"],
                "categories": ["cs.LG"],
                "composite": 8.0 - i * 0.5,
                "score_axis_1": 7, "score_axis_2": 8, "score_axis_3": 6,
            }])
        captured_content = {}

        def mock_create(**kwargs):
            # Capture the prompt so we can count how many papers were sent,
            # and answer with rescored values for the top 2.
            captured_content["text"] = kwargs["messages"][0]["content"]
            results = []
            for i in range(2):
                results.append({
                    "arxiv_id": f"2601.{i:05d}",
                    "code_and_weights": 9, "novelty": 9, "practical_applicability": 9,
                    "summary": "r", "reasoning": "r", "code_url": None,
                })
            resp = MagicMock()
            resp.content = [MagicMock(text=json.dumps(results))]
            return resp

        with patch("anthropic.Anthropic") as MockClient:
            mock_instance = MagicMock()
            mock_instance.messages.create = mock_create
            MockClient.return_value = mock_instance
            from src.scoring import rescore_top
            count = rescore_top(run_id, "aiml", n=2)
        # Should have only sent 2 papers (not 15 from config)
        assert captured_content["text"].count("arxiv_id:") == 2
        assert count == 2
# ---------------------------------------------------------------------------
# 6. _build_batch_content output format
# ---------------------------------------------------------------------------
class TestBuildBatchContent:
    """Verify _build_batch_content sends the right fields for each domain."""

    def test_aiml_content_fields(self):
        from src.scoring import _build_batch_content
        paper = {
            "arxiv_id": "2601.12345",
            "title": "Great New Model",
            "authors": ["Alice", "Bob", "Carol"],
            "categories": ["cs.LG", "cs.CL"],
            "abstract": "We present a new model.",
            "comment": "Accepted at ICML 2026",
            "github_repo": "https://github.com/alice/model",
            "hf_upvotes": 120,
            "hf_models": [{"id": "alice/model-v1", "likes": 50}],
            "hf_spaces": [{"id": "alice/demo", "likes": 10}],
            "source": "both",
        }
        content = _build_batch_content([paper], "aiml", 2000)
        # Every aiml-relevant field must be rendered into the prompt.
        for fragment in (
            "arxiv_id: 2601.12345",
            "title: Great New Model",
            "authors: Alice, Bob, Carol",
            "categories: cs.LG, cs.CL",
            "code_url_found: https://github.com/alice/model",
            "hf_upvotes: 120",
            "hf_models: alice/model-v1",
            "hf_spaces: alice/demo",
            "source: both",
            "abstract: We present a new model.",
            "comment: Accepted at ICML 2026",
        ):
            assert fragment in content
        # Security-only fields must not leak into the aiml prompt.
        for fragment in ("entry_id:", "llm_adjacent:"):
            assert fragment not in content

    def test_security_content_fields(self):
        from src.scoring import _build_batch_content
        paper = {
            "entry_id": "http://arxiv.org/abs/2601.99999v1",
            "arxiv_id": "2601.99999",
            "title": "New Kernel Exploit",
            "authors": ["Mallory"],
            "categories": ["cs.CR"],
            "abstract": "We found a buffer overflow in the Linux kernel.",
            "comment": "10 pages",
            "github_repo": "https://github.com/mallory/poc",
        }
        content = _build_batch_content([paper], "security", 1500)
        for fragment in (
            "entry_id: http://arxiv.org/abs/2601.99999v1",
            "title: New Kernel Exploit",
            "code_url_found: https://github.com/mallory/poc",
            "llm_adjacent: false",
        ):
            assert fragment in content
        # aiml-only fields must not appear in the security prompt.
        for fragment in ("hf_upvotes:", "source:"):
            assert fragment not in content

    def test_security_llm_adjacent_true(self):
        from src.scoring import _build_batch_content
        # Title/abstract mention LLMs — presumably what flips the flag;
        # verify against _build_batch_content's detection logic.
        paper = {
            "entry_id": "http://arxiv.org/abs/2601.88888v1",
            "arxiv_id": "2601.88888",
            "title": "Jailbreaking Large Language Models",
            "authors": ["Eve"],
            "categories": ["cs.CR"],
            "abstract": "We demonstrate a new jailbreak attack on LLMs.",
            "comment": "",
        }
        assert "llm_adjacent: true" in _build_batch_content([paper], "security", 1500)

    def test_abstract_truncation(self):
        from src.scoring import _build_batch_content
        paper = {
            "arxiv_id": "2601.00001",
            "title": "T",
            "abstract": "A" * 5000,
            "authors": [],
            "categories": [],
        }
        content = _build_batch_content([paper], "aiml", 2000)
        # Exactly the first 2000 chars of the abstract survive; the 2001st
        # consecutive 'A' must not appear anywhere.
        assert "abstract: " + "A" * 2000 in content
        assert "A" * 2001 not in content

    def test_missing_code_url(self):
        from src.scoring import _build_batch_content
        paper = {
            "arxiv_id": "2601.00001",
            "title": "No Code Paper",
            "abstract": "Theory only.",
            "authors": [],
            "categories": [],
        }
        # No github_repo key -> explicit "none found" placeholder.
        assert "code_url_found: none found" in _build_batch_content([paper], "aiml", 2000)
# ---------------------------------------------------------------------------
# 7. _apply_scores integration
# ---------------------------------------------------------------------------
class TestApplyScores:
    """Test score application and composite calculation."""

    def test_aiml_score_application(self, test_db):
        import src.config as cfg
        from src.db import get_top_papers, get_unscored_papers
        from src.scoring import _apply_scores

        run_id = _create_test_run(test_db, "aiml")
        _insert_test_papers(test_db, run_id, "aiml", [{
            "arxiv_id": "2601.00001",
            "title": "Test",
            "abstract": "Test",
            "authors": ["A"],
            "categories": ["cs.LG"],
        }])
        # Fetch the unscored row so _apply_scores can match it to the scores.
        unscored = get_unscored_papers(run_id)
        assert len(unscored) == 1
        # aiml scores are keyed by arxiv_id and use the aiml axis names.
        claude_scores = [{
            "arxiv_id": "2601.00001",
            "code_and_weights": 8,
            "novelty": 7,
            "practical_applicability": 9,
            "summary": "Great paper",
            "reasoning": "Novel approach",
            "code_url": "https://github.com/test/repo",
        }]
        applied = _apply_scores(unscored, claude_scores, "aiml", cfg.SCORING_CONFIGS["aiml"])
        assert applied == 1
        # The DB row should now carry the summary, code_url and a composite.
        scored = get_top_papers("aiml", run_id=run_id, limit=1)
        assert len(scored) == 1
        assert scored[0]["summary"] == "Great paper"
        assert scored[0]["code_url"] == "https://github.com/test/repo"
        assert scored[0]["composite"] > 0

    def test_security_score_application(self, test_db):
        import src.config as cfg
        from src.db import get_top_papers, get_unscored_papers
        from src.scoring import _apply_scores

        run_id = _create_test_run(test_db, "security")
        _insert_test_papers(test_db, run_id, "security", [{
            "arxiv_id": "2601.99999",
            "entry_id": "http://arxiv.org/abs/2601.99999v1",
            "title": "Exploit",
            "abstract": "Buffer overflow",
            "authors": ["M"],
            "categories": ["cs.CR"],
        }])
        unscored = get_unscored_papers(run_id)
        assert len(unscored) == 1
        # Security scores are keyed by entry_id (vs arxiv_id for aiml) and
        # use the security axis names.
        claude_scores = [{
            "entry_id": "http://arxiv.org/abs/2601.99999v1",
            "has_code": 6,
            "novel_attack_surface": 9,
            "real_world_impact": 8,
            "summary": "Kernel exploit",
            "reasoning": "Critical",
            "code_url": None,
        }]
        applied = _apply_scores(unscored, claude_scores, "security", cfg.SCORING_CONFIGS["security"])
        assert applied == 1
        scored = get_top_papers("security", run_id=run_id, limit=1)
        assert len(scored) == 1
        assert scored[0]["summary"] == "Kernel exploit"
# ---------------------------------------------------------------------------
# 8. _call_claude model parameter
# ---------------------------------------------------------------------------
class TestCallClaude:
    """Test _call_claude passes the model correctly and handles responses."""

    def test_model_passed_through(self):
        from src.scoring import _call_claude
        seen = {}

        def fake_create(**kwargs):
            seen.update(kwargs)
            response = MagicMock()
            response.content = [MagicMock(text='[{"id": 1}]')]
            return response

        client = MagicMock()
        client.messages.create = fake_create
        parsed = _call_claude(client, "system", "user content", model="claude-opus-4-6")
        # The model kwarg must reach messages.create unchanged, and the JSON
        # in the reply must come back parsed.
        assert seen["model"] == "claude-opus-4-6"
        assert parsed == [{"id": 1}]

    def test_no_json_returns_empty(self):
        from src.scoring import _call_claude
        client = MagicMock()
        reply = MagicMock()
        reply.content = [MagicMock(text="I cannot process this request.")]
        client.messages.create.return_value = reply
        # A refusal with no JSON payload parses to [] rather than raising.
        assert _call_claude(client, "system", "user", model="claude-haiku-4-5-20251001") == []

    def test_model_is_required_keyword(self):
        """model is keyword-only — calling without it should TypeError."""
        from src.scoring import _call_claude
        with pytest.raises(TypeError):
            _call_claude(MagicMock(), "system", "user")
# ---------------------------------------------------------------------------
# 9. Full score_run → rescore_top pipeline flow
# ---------------------------------------------------------------------------
class TestFullPipelineFlow:
    """End-to-end: bulk score with haiku, rescore top with sonnet."""

    def test_score_then_rescore(self, test_db):
        import src.config as cfg
        # Pin the full scoring configuration so the flow is deterministic.
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        cfg.RESCORE_TOP_N = 2
        cfg.BATCH_SIZE = 20
        run_id = _create_test_run(test_db, "aiml")
        # Insert 5 unscored papers
        for i in range(5):
            _insert_test_papers(test_db, run_id, "aiml", [{
                "arxiv_id": f"2601.{i:05d}",
                "title": f"Paper {i}",
                "abstract": f"Abstract for paper {i}",
                "authors": ["Author"],
                "categories": ["cs.LG"],
                "source": "arxiv",
            }])
        call_log = []

        def mock_create(**kwargs):
            # Fake Claude: scores every paper it finds in the prompt, giving
            # higher scores to higher arxiv suffixes (so paper 4 ends up on
            # top), and tags each summary with the model name so we can tell
            # which pass (haiku bulk vs sonnet rescore) wrote it.
            model = kwargs["model"]
            call_log.append(model)
            content = kwargs["messages"][0]["content"]
            # Parse which arxiv_ids are in this batch
            ids = []
            for line in content.split("\n"):
                if line.startswith("arxiv_id: "):
                    ids.append(line.split(": ", 1)[1])
            results = []
            for aid in ids:
                idx = int(aid.split(".")[-1])
                results.append({
                    "arxiv_id": aid,
                    "code_and_weights": 5 + idx,
                    "novelty": 6 + idx,
                    "practical_applicability": 4 + idx,
                    "summary": f"summary-{model}",
                    "reasoning": "r",
                    "code_url": None,
                })
            resp = MagicMock()
            resp.content = [MagicMock(text=json.dumps(results))]
            return resp

        with patch("anthropic.Anthropic") as MockClient:
            mock_instance = MagicMock()
            mock_instance.messages.create = mock_create
            MockClient.return_value = mock_instance
            from src.scoring import rescore_top, score_run
            # Step 1: Bulk score
            scored = score_run(run_id, "aiml")
            assert scored == 5
            # Step 2: Rescore top 2
            rescored = rescore_top(run_id, "aiml")
            assert rescored == 2
        # Verify haiku was used for bulk, sonnet for rescore
        assert call_log[0] == "claude-haiku-4-5-20251001"
        assert call_log[1] == "claude-sonnet-4-5-20250929"
        # Verify the top 2 papers have the sonnet summary
        from src.db import get_top_papers
        top = get_top_papers("aiml", run_id=run_id, limit=5)
        # Top 2 should have sonnet summary, rest haiku
        assert top[0]["summary"] == "summary-claude-sonnet-4-5-20250929"
        assert top[1]["summary"] == "summary-claude-sonnet-4-5-20250929"
        assert top[2]["summary"] == "summary-claude-haiku-4-5-20251001"
# ---------------------------------------------------------------------------
# 10. API key validation uses haiku
# ---------------------------------------------------------------------------
class TestApiKeyValidation:
    """Verify the setup wizard key validation always uses haiku."""

    @pytest.mark.anyio
    async def test_validate_uses_haiku(self):
        seen = {}

        def fake_create(**kwargs):
            seen.update(kwargs)
            reply = MagicMock()
            reply.content = [MagicMock(text="Hi")]
            return reply

        with patch("anthropic.Anthropic") as MockClient:
            instance = MagicMock()
            instance.messages.create = fake_create
            MockClient.return_value = instance
            from httpx import ASGITransport, AsyncClient
            from src.web.app import app
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as client:
                resp = await client.post(
                    "/api/setup/validate-key",
                    json={"api_key": "sk-ant-test-123"},
                )
        assert resp.status_code == 200
        # Key validation is a cheap ping, so it must always use haiku.
        assert seen["model"] == "claude-haiku-4-5-20251001"
# ---------------------------------------------------------------------------
# 11. Setup save persists scoring block
# ---------------------------------------------------------------------------
class TestSetupSave:
    """Verify the setup save endpoint persists the scoring config."""

    @pytest.mark.anyio
    async def test_save_persists_scoring_block(self, tmp_path):
        import src.config as cfg
        # Redirect both the YAML config and the DB into the temp dir.
        cfg.CONFIG_PATH = tmp_path / "config.yaml"
        cfg.DB_PATH = tmp_path / "researcher.db"
        from src.db import init_db
        init_db()
        from httpx import ASGITransport, AsyncClient
        from src.web.app import app
        # Mock reschedule since apscheduler may not be installed in test env
        # NOTE(review): patching save_setup.__module__ looks like a no-op —
        # the effective stubbing is the sys.modules patch.dict and the
        # reschedule patch below. Confirm whether the first patch is needed.
        with patch("src.web.app.save_setup.__module__", "src.web.app"), \
            patch.dict("sys.modules", {"apscheduler": MagicMock(), "apscheduler.schedulers": MagicMock(), "apscheduler.schedulers.background": MagicMock(), "apscheduler.triggers": MagicMock(), "apscheduler.triggers.cron": MagicMock()}):
            # Patch reschedule at the call site
            with patch("src.scheduler.reschedule", return_value=None):
                async with AsyncClient(
                    transport=ASGITransport(app=app), base_url="http://test"
                ) as client:
                    resp = await client.post("/api/setup/save", json={
                        "api_key": "",
                        "scoring": {
                            "model": "claude-opus-4-6",
                            "rescore_model": "claude-sonnet-4-5-20250929",
                            "rescore_top_n": 10,
                        },
                        "domains": {
                            "aiml": {"enabled": True},
                            "security": {"enabled": True},
                        },
                        "schedule": "0 22 * * 0",
                    })
        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "ok"
        # Verify config was saved and reloaded
        assert cfg.SCORING_MODEL == "claude-opus-4-6"
        assert cfg.RESCORE_MODEL == "claude-sonnet-4-5-20250929"
        assert cfg.RESCORE_TOP_N == 10
        # Verify YAML file has the scoring block
        import yaml
        saved = yaml.safe_load(cfg.CONFIG_PATH.read_text())
        assert saved["scoring"]["model"] == "claude-opus-4-6"
        assert saved["scoring"]["rescore_top_n"] == 10