| """Notebook-oriented helpers for GRPO training orchestration.""" |
|
|
| from __future__ import annotations |
|
|
| import random |
| from typing import Any |
|
|
|
|
| def _precision_kwargs(precision: str) -> dict[str, bool]: |
| """Map precision string to TRL config kwargs.""" |
| if precision == "fp16": |
| return {"fp16": True, "bf16": False} |
| if precision == "bf16": |
| return {"fp16": False, "bf16": True} |
| if precision == "fp32": |
| return {"fp16": False, "bf16": False} |
| |
| return {} |
|
|
|
|
def sample_random_baseline(
    prompts: list[str],
    *,
    step_budget: int,
    seed: int,
) -> list[dict[str, Any]]:
    """Produce random-policy transcripts as a cheap comparison baseline.

    One transcript is emitted per prompt. Each transcript contains between
    1 and 5 "ACTION: argument" lines (the budget is clamped to that range),
    with actions drawn from a fixed seeded RNG so results are reproducible.
    """
    generator = random.Random(seed)
    available_actions = ["DESCRIBE", "SAMPLE", "QUERY", "ANSWER"]
    # Clamp once up front — the budget is the same for every prompt.
    num_steps = max(1, min(step_budget, 5))

    results: list[dict[str, Any]] = []
    for prompt in prompts:
        step_lines = []
        for _ in range(num_steps):
            picked = generator.choice(available_actions)
            # QUERY carries a SQL snippet; every other action targets a table.
            picked_argument = "SELECT 1" if picked == "QUERY" else "table_1"
            step_lines.append(f"{picked}: {picked_argument}")

        transcript_text = "\n".join(step_lines)
        results.append(
            {
                "prompt": prompt,
                "completion": transcript_text,
                "content": transcript_text,
                "metadata": {"policy": "random", "step_count": num_steps},
            }
        )

    return results
|
|
|
|
def build_trainer(
    *,
    model: Any,
    tokenizer: Any,
    prompts: list[str],
    config: Any,
    trl_grpo_config_cls: type,
    grpo_trainer_cls: type,
    reward_funcs: list[Any],
    environment_factory: type | None = None,
    callbacks: list[Any] | None = None,
) -> Any:
    """Assemble a GRPO trainer from notebook-level configuration objects.

    The TRL config class and trainer class are injected so the notebook can
    swap implementations. When an ``environment_factory`` is supplied, its
    ``configure`` hook (or private ``_configure`` fallback) is invoked with
    dataset/environment paths before the factory is handed to the trainer.
    """
    # Optional TRL config flags: mixed-precision plus gradient checkpointing.
    optional_flags: dict[str, Any] = dict(
        _precision_kwargs(getattr(config, "precision", "auto"))
    )
    if getattr(config, "gradient_checkpointing", False):
        optional_flags["gradient_checkpointing"] = True

    trainer_config = trl_grpo_config_cls(
        output_dir=config.output_dir,
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        num_train_epochs=config.num_train_epochs,
        logging_steps=config.logging_steps,
        max_completion_length=config.max_new_tokens,
        num_generations=config.num_generations,
        generation_batch_size=config.num_generations,
        beta=getattr(config, "beta", 0.04),
        remove_unused_columns=False,
        log_completions=True,
        num_completions_to_print=1,
        chat_template_kwargs={
            "enable_thinking": getattr(config, "enable_thinking", False),
        },
        **optional_flags,
    )

    constructor_kwargs: dict[str, Any] = {
        "model": model,
        "processing_class": tokenizer,
        "args": trainer_config,
        "train_dataset": prompts,
        "reward_funcs": reward_funcs,
    }

    if environment_factory is not None:
        # Prefer the public hook; fall back to the private name when the
        # public attribute is missing or not callable.
        hook = getattr(environment_factory, "configure", None)
        if not callable(hook):
            hook = getattr(environment_factory, "_configure", None)
        if callable(hook):
            hook(
                questions_path=config.questions_path,
                db_dir=config.db_dir,
                step_budget=config.step_budget,
            )
        constructor_kwargs["environment_factory"] = environment_factory

    if callbacks is not None:
        constructor_kwargs["callbacks"] = callbacks

    return grpo_trainer_cls(**constructor_kwargs)
|
|
|
|
def run_training_with_metrics(trainer: Any) -> tuple[Any, list[int], list[float]]:
    """Execute ``trainer.train()`` and extract (step, reward) series for plots.

    Returns the raw train output plus parallel lists of steps and rewards
    pulled from ``trainer.state.log_history``. Entries that are not dicts or
    lack either a ``step`` or ``reward`` key are skipped, so a trainer with
    no usable history yields two empty lists.
    """
    train_output = trainer.train()

    # Defensively walk trainer.state.log_history; any missing attribute or
    # a non-list value collapses to an empty history.
    raw_history = getattr(getattr(trainer, "state", None), "log_history", None)
    entries = raw_history if isinstance(raw_history, list) else []

    step_series: list[int] = []
    reward_series: list[float] = []
    for entry in entries:
        if isinstance(entry, dict) and "step" in entry and "reward" in entry:
            step_series.append(int(entry["step"]))
            reward_series.append(float(entry["reward"]))

    return train_output, step_series, reward_series
|
|
|
|
def format_oom_guidance(error: Exception) -> str:
    """Build a human-readable remediation hint for an out-of-memory failure.

    The original exception text is embedded so logs retain the root cause.
    """
    remediation = "Try reducing per_device_train_batch_size or num_generations."
    return f"Training failed with OOM: {error}. {remediation}"
|
|