| """E2E-style smoke coverage for the GRPO training notebook.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| from sql_env.training import rollout as rollout_module |
| from sql_env.training.config import GRPOConfig |
| from sql_env.training.notebook_pipeline import ( |
| build_trainer, |
| run_training_with_metrics, |
| sample_random_baseline, |
| ) |
| from sql_env.training.data_loading import filter_questions_by_difficulty |
| from sql_env.training.rewards import ( |
| reward_correctness, |
| reward_operational, |
| reward_progress, |
| ) |
| from sql_env.training.rollout import rollout_func |
|
|
|
|
# Location of the notebook under test, relative to the repository root.
NOTEBOOK_PATH = Path("notebooks/train_grpo.ipynb")
|
|
|
|
def _read_notebook() -> dict:
    """Load and parse the training notebook's JSON from disk."""
    raw = NOTEBOOK_PATH.read_text(encoding="utf-8")
    return json.loads(raw)
|
|
|
|
| def _code_sources(notebook: dict) -> list[str]: |
| cells = notebook.get("cells", []) |
| return [ |
| "".join(cell.get("source", [])) |
| for cell in cells |
| if cell.get("cell_type") == "code" |
| ] |
|
|
|
|
def test_training_notebook_smoke_structure() -> None:
    """Notebook includes the core GRPO training flow cells."""
    assert NOTEBOOK_PATH.exists(), "notebooks/train_grpo.ipynb must exist"

    joined = "\n".join(_code_sources(_read_notebook()))

    # Core pipeline pieces must all appear somewhere in the code cells.
    assert "GRPOConfig(" in joined
    assert "load_model_and_tokenizer(config.model_name)" in joined
    assert "grpo_trainer_cls=GRPOTrainer" in joined
    assert "run_training_with_metrics" in joined
    assert "matplotlib.pyplot as plt" in joined

    # The random baseline must be sampled before training runs.
    baseline_at = joined.find("before_rollouts = sample_random_baseline")
    training_at = joined.find("run_training_with_metrics(trainer)")
    assert baseline_at != -1
    assert training_at != -1
    assert baseline_at < training_at
|
|
|
|
def test_question_filtering_by_difficulty() -> None:
    """Difficulty filtering keeps only questions in the allowed set."""
    sample = [
        {"question_text": f"q{index}", "difficulty": level}
        for index, level in enumerate(("easy", "medium", "hard"), start=1)
    ]

    kept = filter_questions_by_difficulty(sample, ["easy"])
    assert [entry["question_text"] for entry in kept] == ["q1"]
|
|
|
|
| class _FakeTokenizer: |
| def apply_chat_template( |
| self, |
| messages: list[dict[str, str]], |
| tokenize: bool = False, |
| add_generation_prompt: bool = True, |
| ) -> str: |
| del messages |
| del tokenize |
| del add_generation_prompt |
| return "prompt" |
|
|
|
|
| class _FakeModel: |
| def __init__(self) -> None: |
| self._count = 0 |
|
|
| def generate(self, prompt: str, max_new_tokens: int) -> str: |
| del prompt |
| del max_new_tokens |
| self._count += 1 |
| if self._count == 1: |
| return "QUERY: SELECT 1" |
| return "ANSWER: 42" |
|
|
|
|
class _FakeEnvironment:
    """Deterministic SQL-environment stub used to drive rollouts in tests."""

    def __init__(self, step_budget: int) -> None:
        self.step_budget = step_budget
        self.step_count = 0
        # Minimal stand-in for the real environment state object.
        self.state = type("State", (), {"episode_id": "ep-e2e"})()

    def reset(self, *, seed: int | None = None):
        """Restart the episode and return the initial observation."""
        del seed
        self.step_count = 0
        return self._observation(done=False, result="")

    def step(self, action):
        """Advance one step; an ANSWER action ends the episode with reward 1."""
        self.step_count += 1
        answered = getattr(action, "action_type", "") == "ANSWER"
        if not answered:
            return self._observation(done=False, result="ok", reward=0.1)
        return self._observation(
            done=True, result="Answer submitted: correct.", reward=1.0
        )

    def _observation(self, done: bool, result: str, reward: float | None = 0.0):
        # Imported lazily so importing this module never pulls in sql_env.models.
        from sql_env.models import SQLObservation

        return SQLObservation(
            question="How many rows?",
            schema_info="Available tables:\n- t",
            result=result,
            error="",
            step_count=self.step_count,
            budget_remaining=max(0, self.step_budget - self.step_count),
            action_history=[],
            done=done,
            reward=reward,
        )
|
|
|
|
def test_training_pipeline_smoke(monkeypatch) -> None:
    """Happy-path rollout + reward computation produces trainable signals."""
    config = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=2,
    )
    fake_env = _FakeEnvironment(step_budget=2)
    # Swap the real environment builder for the deterministic fake.
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    rollouts = rollout_func(["Count rows"], _FakeModel(), _FakeTokenizer(), config)
    assert len(rollouts) == 1

    metadata = [entry["metadata"] for entry in rollouts]
    completions = [
        [{"role": "assistant", "content": entry["content"]}] for entry in rollouts
    ]

    correctness = reward_correctness(completions, metadata=metadata)
    progress = reward_progress(completions, metadata=metadata)
    operational = reward_operational(completions, metadata=metadata)

    assert correctness == [1.0]
    assert len(progress) == 1
    assert 0.0 <= progress[0] <= 1.0
    assert len(operational) == 1
|
|
|
|
| class _FakeTRLConfig: |
| def __init__(self, **kwargs): |
| self.kwargs = kwargs |
|
|
|
|
| class _FakeTrainer: |
| def __init__( |
| self, |
| *, |
| model, |
| processing_class, |
| args, |
| train_dataset, |
| reward_funcs, |
| ) -> None: |
| self.model = model |
| self.processing_class = processing_class |
| self.args = args |
| self.train_dataset = train_dataset |
| self.reward_funcs = reward_funcs |
| self.state = type("State", (), {"log_history": []})() |
| self.train_called = False |
|
|
| def train(self) -> dict[str, str]: |
| self.train_called = True |
| self.state.log_history = [{"step": 1, "reward": 0.25}] |
| return {"status": "ok"} |
|
|
|
|
def test_notebook_pipeline_executes_training_step(monkeypatch) -> None:
    """Notebook pipeline helper builds trainer and executes train()."""
    config = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=2,
    )
    fake_env = _FakeEnvironment(step_budget=2)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    trainer = build_trainer(
        model=_FakeModel(),
        tokenizer=_FakeTokenizer(),
        prompts=[{"prompt": "Count rows"}],
        config=config,
        trl_grpo_config_cls=_FakeTRLConfig,
        grpo_trainer_cls=_FakeTrainer,
        reward_funcs=[reward_correctness, reward_progress, reward_operational],
    )

    output, steps, rewards = run_training_with_metrics(trainer)

    # The fake trainer logs exactly one step with a fixed reward.
    assert trainer.train_called is True
    assert output == {"status": "ok"}
    assert steps == [1]
    assert rewards == [0.25]
|
|
|
|
def test_random_baseline_transcripts_are_generated() -> None:
    """Random baseline helper generates readable transcripts per prompt."""
    baseline = sample_random_baseline(["q1", "q2"], step_budget=3, seed=7)

    assert len(baseline) == 2
    for entry in baseline:
        assert entry["metadata"]["policy"] == "random"
        assert entry["completion"]
|
|