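"""GRPO training entry point for debugZero (training/grpo_train.py).

Trains one policy in two alternating roles: a proposer that injects bugs
into seed programs and a solver that repairs buggy programs, with rewards
computed by executing candidates against each seed's tests.
"""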
from __future__ import annotations
import importlib.util
import json
import math
import os
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from datasets import Dataset
try:
from unsloth import FastLanguageModel, is_bfloat16_supported
HAS_UNSLOTH = True
except ImportError:
HAS_UNSLOTH = False
def is_bfloat16_supported() -> bool:
return False
try:
from unsloth import PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
except ImportError:
pass
try:
from bug_bank import BugBank, BugSample, build_bug_bank
from seed_bank import SEED_BANK, SeedSpec, get_seed_by_id
from server.bug_injector import infer_bug_operator
from server.executor import execute_code
from server.graders import (
compute_ast_distance,
compute_proposer_reward,
compute_solver_reward,
is_effectively_unchanged,
reset_reward_history,
)
from training.dual_role_sampler import sample_proposer_prompt, sample_solver_prompt
except ImportError:
from ..bug_bank import BugBank, BugSample, build_bug_bank
from ..seed_bank import SEED_BANK, SeedSpec, get_seed_by_id
from ..server.bug_injector import infer_bug_operator
from ..server.executor import execute_code
from ..server.graders import (
compute_ast_distance,
compute_proposer_reward,
compute_solver_reward,
is_effectively_unchanged,
reset_reward_history,
)
from .dual_role_sampler import sample_proposer_prompt, sample_solver_prompt
DEFAULT_MODEL_ID = "unsloth/Qwen2.5-Coder-3B-Instruct"
DEFAULT_FALLBACK_MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
DEFAULT_OUTPUT_DIR = Path("debugzero_model")
DEFAULT_SOLVER_WEIGHT = 2
DEFAULT_NUM_GENERATIONS = 4
DEFAULT_MAX_STEPS = 80
DEFAULT_MAX_PROMPT_LENGTH = 768
DEFAULT_MAX_COMPLETION_LENGTH = 256
DRY_RUN_MAX_STEPS = 2
DEFAULT_PROPOSER_METRICS_PATH = DEFAULT_OUTPUT_DIR / "proposer_metrics.json"
TARGETED_PROMPT_RATIO = 0.75
def extract_python_code(text: str) -> str:
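    """Return the first fenced code block in ``text`` (optionally tagged
    ``python``), falling back to the stripped text when no fence is found.
    """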
match = re.search(r"```(?:python)?\s(.*?)```", text, flags=re.DOTALL)
if match:
return match.group(1).strip()
return text.strip()
def completion_to_text(completion) -> str:
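    """Normalize a GRPO completion (chat-message list or plain value) to text."""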
if isinstance(completion, list) and completion:
item = completion[0]
if isinstance(item, dict):
return item.get("content", "")
return str(item)
return str(completion)
def prompt_to_text(prompt) -> str:
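    """Flatten a prompt (message list, dict, or string) into plain text."""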
if isinstance(prompt, list):
parts = []
for item in prompt:
if isinstance(item, dict):
parts.append(str(item.get("content", "")))
else:
parts.append(str(item))
return "\n".join(part for part in parts if part)
if isinstance(prompt, dict):
return str(prompt.get("content", ""))
return str(prompt)
def execute_candidate(seed: SeedSpec, candidate_code: str) -> dict[str, object]:
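    """Run ``candidate_code`` against the seed's test and return pass,
    syntax-error, and unsafe-code flags plus truncated execution output.
    """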
result = execute_code(candidate_code, seed.test)
execution_result = result.output[:500] if result.output else ""
unsafe_code = execution_result.startswith("Unsafe import detected.")
return {
"tests_passed": result.passed,
"syntax_error": result.syntax_error,
"unsafe_code": unsafe_code,
"execution_result": execution_result,
}
def build_mixed_role_dataset(
bug_bank: BugBank,
solver_weight: int = DEFAULT_SOLVER_WEIGHT,
) -> Dataset:
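    """Build the training dataset: one solver row per training bug plus
    weighted proposer rows at roughly one per ``solver_weight`` solver rows.
    """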
rows: list[dict[str, object]] = []
for bug_sample in bug_bank.train_samples:
prompt_text = sample_solver_prompt(
bug_sample.buggy_code,
bug_sample.execution_result,
mode="concise",
)
rows.append(
{
"prompt": [{"role": "user", "content": prompt_text}],
"role": "solver",
"seed_id": bug_sample.seed_id,
"original_code": bug_sample.original_code,
"buggy_code": bug_sample.buggy_code,
"bug_operator": bug_sample.bug_operator,
"execution_result": bug_sample.execution_result,
}
)
target_proposer_rows = max(1, math.ceil(len(rows) / solver_weight)) if rows else len(SEED_BANK)
proposer_rows = build_weighted_proposer_rows(bug_bank, target_proposer_rows)
rows.extend(proposer_rows)
return Dataset.from_list(rows)
def create_dataset() -> tuple[Dataset, BugBank]:
bug_bank = build_bug_bank()
return build_mixed_role_dataset(bug_bank), bug_bank
def prop_rew(prompts, completions, **kwargs):
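    """GRPO reward for proposer rows, computed by ``compute_proposer_reward``
    from execution results, an unchanged-code check, and an AST-distance
    plausibility score. Non-proposer rows receive 0.0.
    """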
rewards: list[float] = []
roles = kwargs.get("role", [])
seed_ids = kwargs.get("seed_id", [])
original_codes = kwargs.get("original_code", [])
for i, completion in enumerate(completions):
role = roles[i] if i < len(roles) else roles[0]
if role != "proposer":
rewards.append(0.0)
continue
seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0]
original_code = original_codes[i] if i < len(original_codes) else original_codes[0]
seed = get_seed_by_id(seed_id)
candidate_code = extract_python_code(completion_to_text(completion))
execution_meta = execute_candidate(seed, candidate_code)
unchanged_code = is_effectively_unchanged(original_code, candidate_code)
changed_but_passing = (
(not unchanged_code)
and execution_meta["tests_passed"]
and (not execution_meta["syntax_error"])
)
proposer_meta = {
"seed_id": seed.seed_id,
"tests_passed": execution_meta["tests_passed"],
"syntax_error": execution_meta["syntax_error"],
"unsafe_code": execution_meta["unsafe_code"],
"unchanged_code": unchanged_code,
"changed_but_passing": changed_but_passing,
"plausibility_score": 0.0,
}
if not execution_meta["syntax_error"]:
proposer_meta["plausibility_score"] = compute_ast_distance(original_code, candidate_code)
rewards.append(compute_proposer_reward(proposer_meta))
return rewards
def solv_rew(prompts, completions, **kwargs):
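    """GRPO reward for solver rows, computed by ``compute_solver_reward``
    from execution results. Non-solver rows receive 0.0.
    """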
rewards: list[float] = []
roles = kwargs.get("role", [])
seed_ids = kwargs.get("seed_id", [])
for i, completion in enumerate(completions):
role = roles[i] if i < len(roles) else roles[0]
if role != "solver":
rewards.append(0.0)
continue
seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0]
seed = get_seed_by_id(seed_id)
candidate_code = extract_python_code(completion_to_text(completion))
execution_meta = execute_candidate(seed, candidate_code)
solver_meta = {
"seed_id": seed.seed_id,
"tests_passed": execution_meta["tests_passed"],
"syntax_error": execution_meta["syntax_error"],
"unsafe_code": execution_meta["unsafe_code"],
}
rewards.append(compute_solver_reward(solver_meta))
return rewards
def evaluate_bug_sample(candidate_code: str, bug_sample: BugSample) -> dict[str, object]:
seed = get_seed_by_id(bug_sample.seed_id)
evaluation = execute_candidate(seed, candidate_code)
reward = compute_solver_reward(
{
"seed_id": bug_sample.seed_id,
"tests_passed": evaluation["tests_passed"],
"syntax_error": evaluation["syntax_error"],
"unsafe_code": evaluation["unsafe_code"],
}
)
return {**evaluation, "reward": reward}
def evaluate_solver_fixed_set(model, tokenizer, bug_bank: BugBank) -> dict[str, float]:
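    """Greedy-decode a fix for each held-out eval bug and summarize pass,
    syntax-error, and mean-reward rates.
    """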
results = []
for bug_sample in bug_bank.eval_samples:
prompt = sample_solver_prompt(
bug_sample.buggy_code,
bug_sample.execution_result,
mode="concise",
)
candidate_code = generate_code(model, tokenizer, prompt, do_sample=False)
results.append(evaluate_bug_sample(candidate_code, bug_sample))
return summarize_solver_results(results)
def evaluate_proposer_fixed_set(model, tokenizer) -> dict[str, float]:
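    """Greedy-decode one candidate bug per seed and summarize proposer
    metrics, including per-seed and per-bug-family breakdowns.
    """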
results = []
for seed in SEED_BANK:
prompt = sample_proposer_prompt(seed.original_code)
candidate_code = generate_code(model, tokenizer, prompt, do_sample=False)
evaluation = execute_candidate(seed, candidate_code)
unchanged_code = is_effectively_unchanged(seed.original_code, candidate_code)
valid_bug = (not evaluation["tests_passed"]) and (not evaluation["syntax_error"])
changed_but_passing = (
(not unchanged_code)
and evaluation["tests_passed"]
and (not evaluation["syntax_error"])
)
reward = compute_proposer_reward(
{
"seed_id": seed.seed_id,
"tests_passed": evaluation["tests_passed"],
"syntax_error": evaluation["syntax_error"],
"unsafe_code": evaluation["unsafe_code"],
"unchanged_code": unchanged_code,
"changed_but_passing": changed_but_passing,
"plausibility_score": 0.0
if evaluation["syntax_error"]
else compute_ast_distance(seed.original_code, candidate_code),
}
)
results.append(
{
"seed_id": seed.seed_id,
**evaluation,
"reward": reward,
"unchanged_code": unchanged_code,
"valid_bug": valid_bug,
"changed_but_passing": changed_but_passing,
"likely_bug_family": infer_bug_operator(seed.original_code, candidate_code) or "unknown",
}
)
summary = summarize_proposer_results(results)
summary["by_seed"] = summarize_proposer_by_seed(results)
summary["by_bug_family"] = summarize_proposer_by_bug_family(results)
return summary
def summarize_solver_results(results: list[dict[str, object]]) -> dict[str, float]:
total = len(results) or 1
passed = sum(1 for result in results if result["tests_passed"])
syntax_errors = sum(1 for result in results if result["syntax_error"])
mean_reward = sum(float(result["reward"]) for result in results) / total
return {
"pass_rate": passed / total,
"syntax_error_rate": syntax_errors / total,
"mean_reward": mean_reward,
}
def summarize_proposer_results(results: list[dict[str, object]]) -> dict[str, float]:
total = len(results) or 1
bug_rate = sum(
1 for result in results if (not result["tests_passed"]) and (not result["syntax_error"])
)
unchanged = sum(1 for result in results if result.get("unchanged_code"))
changed_but_passing = sum(1 for result in results if result.get("changed_but_passing"))
syntax_errors = sum(1 for result in results if result["syntax_error"])
mean_reward = sum(float(result["reward"]) for result in results) / total
return {
"break_rate": bug_rate / total,
"valid_bug_rate": bug_rate / total,
"unchanged_rate": unchanged / total,
"changed_but_passing_rate": changed_but_passing / total,
"syntax_error_rate": syntax_errors / total,
"mean_reward": mean_reward,
}
def summarize_proposer_by_seed(results: list[dict[str, object]]) -> dict[str, dict[str, float]]:
grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
for result in results:
grouped[str(result["seed_id"])].append(result)
summary: dict[str, dict[str, float]] = {}
for seed_id, seed_results in grouped.items():
total = len(seed_results) or 1
summary[seed_id] = {
"valid_bug_rate": sum(1 for item in seed_results if item.get("valid_bug")) / total,
"unchanged_rate": sum(1 for item in seed_results if item.get("unchanged_code")) / total,
"changed_but_passing_rate": sum(
1 for item in seed_results if item.get("changed_but_passing")
)
/ total,
"mean_reward": sum(float(item["reward"]) for item in seed_results) / total,
}
return summary
def summarize_proposer_by_bug_family(results: list[dict[str, object]]) -> dict[str, dict[str, float]]:
grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
for result in results:
grouped[str(result.get("likely_bug_family", "unknown"))].append(result)
summary: dict[str, dict[str, float]] = {}
for family, family_results in grouped.items():
total = len(family_results) or 1
summary[family] = {
"count": float(total),
"valid_bug_rate": sum(1 for item in family_results if item.get("valid_bug")) / total,
"mean_reward": sum(float(item["reward"]) for item in family_results) / total,
}
return summary
def build_weighted_proposer_rows(bug_bank: BugBank, target_proposer_rows: int) -> list[dict[str, object]]:
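    """Build proposer rows, upweighting seeds whose prior break rate was low
    so the proposer gets more practice on seeds it rarely breaks.
    """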
if target_proposer_rows <= 0:
return []
prior_seed_rates = load_prior_seed_break_rates()
operator_counts = Counter(sample.bug_operator for sample in bug_bank.train_samples)
seed_to_operators: dict[str, list[str]] = defaultdict(list)
for sample in bug_bank.train_samples:
seed_to_operators[sample.seed_id].append(sample.bug_operator)
seed_weights = {}
for seed in SEED_BANK:
prior_break_rate = prior_seed_rates.get(seed.seed_id, 0.5)
seed_weights[seed.seed_id] = max(1, 1 + round((1.0 - prior_break_rate) * 2))
rows: list[dict[str, object]] = []
focus_counters = Counter()
ordered_seeds = sorted(SEED_BANK, key=lambda seed: (-seed_weights[seed.seed_id], seed.seed_id))
# Keep every seed represented before adding extra weight to weak seeds.
for seed in SEED_BANK[:target_proposer_rows]:
bug_focus = choose_proposer_bug_focus(
seed.seed_id,
seed_to_operators[seed.seed_id],
operator_counts,
focus_counters,
len(rows),
target_proposer_rows,
)
prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus)
rows.append(
{
"prompt": [{"role": "user", "content": prompt_text}],
"role": "proposer",
"seed_id": seed.seed_id,
"original_code": seed.original_code,
"bug_focus": bug_focus if bug_focus else "",
}
)
    while len(rows) < target_proposer_rows:
        rows_before = len(rows)
        for seed in ordered_seeds:
            extra_weight = max(0, seed_weights[seed.seed_id] - 1)
for _ in range(extra_weight):
if len(rows) >= target_proposer_rows:
break
bug_focus = choose_proposer_bug_focus(
seed.seed_id,
seed_to_operators[seed.seed_id],
operator_counts,
focus_counters,
len(rows),
target_proposer_rows,
)
prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus)
rows.append(
{
"prompt": [{"role": "user", "content": prompt_text}],
"role": "proposer",
"seed_id": seed.seed_id,
"original_code": seed.original_code,
"bug_focus": bug_focus if bug_focus else "",
}
)
            if len(rows) >= target_proposer_rows:
                break
        if len(rows) == rows_before:
            # Every seed had zero extra weight (all prior break rates were
            # high), so this pass added nothing; pad with untargeted rows so
            # the while-loop cannot spin forever.
            for seed in ordered_seeds:
                if len(rows) >= target_proposer_rows:
                    break
                prompt_text = sample_proposer_prompt(seed.original_code)
                rows.append(
                    {
                        "prompt": [{"role": "user", "content": prompt_text}],
                        "role": "proposer",
                        "seed_id": seed.seed_id,
                        "original_code": seed.original_code,
                        "bug_focus": "",
                    }
                )
            if len(rows) == rows_before:
                break
    return rows
def choose_proposer_bug_focus(
seed_id: str,
operators: list[str],
operator_counts: Counter,
focus_counters: Counter,
row_index: int,
total_rows: int,
) -> str | None:
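    """Pick the least-used bug operator as the prompt focus for the first
    ``TARGETED_PROMPT_RATIO`` share of rows; later rows stay untargeted.
    """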
unique_operators = sorted(set(operators), key=lambda op: (operator_counts[op], op))
if not unique_operators:
return None
if row_index >= math.ceil(total_rows * TARGETED_PROMPT_RATIO):
return None
del seed_id
chosen = min(unique_operators, key=lambda op: (focus_counters[op], operator_counts[op], op))
focus_counters[chosen] += 1
return chosen
def load_prior_seed_break_rates() -> dict[str, float]:
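    """Load per-seed valid-bug rates from a previous run's metrics artifact,
    returning an empty mapping if the artifact is missing or unreadable.
    """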
if not DEFAULT_PROPOSER_METRICS_PATH.exists():
return {}
try:
data = json.loads(DEFAULT_PROPOSER_METRICS_PATH.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
return {}
seed_metrics = data.get("post_proposer_metrics", {}).get("by_seed", {})
return {
str(seed_id): float(metrics.get("valid_bug_rate", 0.5))
for seed_id, metrics in seed_metrics.items()
if isinstance(metrics, dict)
}
def save_metrics_artifact(
pre_proposer_metrics: dict[str, object],
post_proposer_metrics: dict[str, object],
) -> Path:
DEFAULT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
artifact = {
"pre_proposer_metrics": pre_proposer_metrics,
"post_proposer_metrics": post_proposer_metrics,
}
DEFAULT_PROPOSER_METRICS_PATH.write_text(
json.dumps(artifact, indent=2, sort_keys=True),
encoding="utf-8",
)
return DEFAULT_PROPOSER_METRICS_PATH
def generate_code(
model,
tokenizer,
prompt: str | list[dict[str, str]],
*,
do_sample: bool,
max_new_tokens: int = DEFAULT_MAX_COMPLETION_LENGTH,
) -> str:
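    """Generate one completion for ``prompt`` (greedy unless ``do_sample``)
    and return the extracted Python code block.
    """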
import torch
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.eval()
if isinstance(prompt, list):
prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
else:
prompt_text = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
encoded = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=DEFAULT_MAX_PROMPT_LENGTH)
model_device = next(model.parameters()).device
encoded = {key: value.to(model_device) for key, value in encoded.items()}
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
}
if do_sample:
generation_kwargs["temperature"] = 0.7
generation_kwargs["top_p"] = 0.95
with torch.no_grad():
output = model.generate(**encoded, **generation_kwargs)
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # Strip the echoed prompt text; comparing against the raw ``prompt``
    # argument would fail, since it may be a chat-message list rather than
    # a string.
    completion = decoded[len(prompt_text) :] if decoded.startswith(prompt_text) else decoded
    return extract_python_code(completion)
def get_training_profile(dry_run: bool) -> dict[str, int | float | bool | str]:
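    """Return GRPO hyperparameters, shrunk for ``dry_run`` smoke tests and
    avoiding the 8-bit optimizer when bitsandbytes is unavailable.
    """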
has_bitsandbytes = importlib.util.find_spec("bitsandbytes") is not None
return {
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 4,
"learning_rate": 2e-5,
"max_steps": DRY_RUN_MAX_STEPS if dry_run else DEFAULT_MAX_STEPS,
"num_generations": 2 if dry_run else DEFAULT_NUM_GENERATIONS,
"max_completion_length": DEFAULT_MAX_COMPLETION_LENGTH,
"report_to": "none",
"optim": "adamw_torch" if dry_run or not has_bitsandbytes else "adamw_8bit",
}
def load_training_model_and_tokenizer(
dry_run: bool,
dataset: Dataset,
bug_bank: BugBank,
):
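    """Load the policy: a tiny local GPT-2 for dry runs, Unsloth 4-bit LoRA
    when available, otherwise standard Transformers with a PEFT LoRA adapter.
    """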
if dry_run:
return build_tiny_local_model_and_tokenizer(dataset, bug_bank)
if HAS_UNSLOTH:
print("Initializing Unsloth FastLanguageModel...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=DEFAULT_MODEL_ID,
max_seq_length=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH,
load_in_4bit=True,
fast_inference=True,
)
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
)
return model, tokenizer
    # Unsloth failed to import (e.g., a Kaggle/Colab CUDA mismatch), so fall
    # back to standard Hugging Face Transformers with a PEFT LoRA adapter.
    print("Unsloth not available. Falling back to standard Transformers loading.")
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_FALLBACK_MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        DEFAULT_FALLBACK_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
peft_config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
return model, tokenizer
def build_tiny_local_model_and_tokenizer(dataset: Dataset, bug_bank: BugBank):
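    """Build a word-level tokenizer (trained on the local corpus) and a
    randomly initialized two-layer GPT-2 so dry runs need no model download.
    """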
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast
corpus = [prompt_to_text(row["prompt"]) for row in dataset]
corpus.extend(sample.original_code for sample in bug_bank.train_samples)
corpus.extend(sample.buggy_code for sample in bug_bank.train_samples)
corpus.extend(sample.original_code for sample in bug_bank.eval_samples)
corpus.extend(sample.buggy_code for sample in bug_bank.eval_samples)
corpus.extend(seed.test for seed in SEED_BANK)
tokenizer_object = Tokenizer(WordLevel(unk_token="<unk>"))
tokenizer_object.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(
special_tokens=["<pad>", "<bos>", "<eos>", "<unk>"],
min_frequency=1,
)
tokenizer_object.train_from_iterator(corpus, trainer=trainer)
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer_object,
bos_token="<bos>",
eos_token="<eos>",
unk_token="<unk>",
pad_token="<pad>",
)
tokenizer.chat_template = (
"{% for message in messages %}"
"{{ message['role'] }}: {{ message['content'] }}\n"
"{% endfor %}"
"{% if add_generation_prompt %}assistant: {% endif %}"
)
model_config = GPT2Config(
vocab_size=tokenizer.vocab_size,
n_positions=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH,
n_ctx=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH,
n_embd=128,
n_layer=2,
n_head=2,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
model = GPT2LMHeadModel(model_config)
return model, tokenizer
def get_trl_classes():
if os.name == "nt" and not sys.flags.utf8_mode:
print("Windows detected. Use `python -X utf8` when running this file locally.")
from trl import GRPOConfig, GRPOTrainer
return GRPOConfig, GRPOTrainer
def create_trainer(model, tokenizer, dataset: Dataset, dry_run: bool):
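    """Construct the GRPO trainer with both reward functions, keeping only
    the config kwargs supported by the installed TRL version.
    """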
GRPOConfig, GRPOTrainer = get_trl_classes()
profile = get_training_profile(dry_run)
    import inspect
    # Only pass kwargs that the installed TRL version's GRPOConfig accepts.
    supported_kwargs = inspect.signature(GRPOConfig.__init__).parameters
config_kwargs = {
"output_dir": str(DEFAULT_OUTPUT_DIR),
"per_device_train_batch_size": profile["per_device_train_batch_size"],
"gradient_accumulation_steps": profile["gradient_accumulation_steps"],
"learning_rate": profile["learning_rate"],
"max_steps": profile["max_steps"],
"num_generations": profile["num_generations"],
"max_prompt_length": DEFAULT_MAX_PROMPT_LENGTH,
"max_completion_length": profile["max_completion_length"],
"bf16": (not dry_run) and HAS_UNSLOTH and is_bfloat16_supported(),
"fp16": (not dry_run) and not is_bfloat16_supported(),
"use_cpu": dry_run,
"logging_steps": 1 if dry_run else 5,
"optim": profile["optim"],
"report_to": profile["report_to"],
"disable_tqdm": True,
}
training_args = GRPOConfig(**{k: v for k, v in config_kwargs.items() if k in supported_kwargs})
print(f"Starting GRPO training for {training_args.max_steps} episodes (steps)...")
print("To change the number of episodes, modify 'max_steps' in the training profile.")
return GRPOTrainer(
model=model,
reward_funcs=[prop_rew, solv_rew],
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
)
def save_results_plot(
pre_solver_metrics: dict[str, float],
post_solver_metrics: dict[str, float],
pre_proposer_metrics: dict[str, float],
post_proposer_metrics: dict[str, float],
log_history: list[dict[str, float]],
) -> Path | None:
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
except ImportError:
print("matplotlib is not installed, skipping plot generation.")
return None
DEFAULT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
plot_path = DEFAULT_OUTPUT_DIR / "debugzero_results.png"
figure, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].bar(
["solver pre", "solver post", "proposer pre", "proposer post"],
[
pre_solver_metrics["pass_rate"],
post_solver_metrics["pass_rate"],
pre_proposer_metrics["break_rate"],
post_proposer_metrics["break_rate"],
],
color=["#4f81bd", "#4f81bd", "#c0504d", "#c0504d"],
)
axes[0].set_ylim(0.0, 1.0)
axes[0].set_title("Fixed Eval Rates")
axes[0].set_ylabel("Rate")
    # Pair each step with the loss from the same log entry so the x/y series
    # stay aligned even when some entries (e.g., the final summary) lack a loss.
    loss_points = [
        (entry["step"], entry["loss"])
        for entry in log_history
        if "step" in entry and "loss" in entry
    ]
    if loss_points:
        axes[1].plot(
            [step for step, _ in loss_points],
            [loss for _, loss in loss_points],
            marker="o",
        )
axes[1].set_title("Training Loss")
axes[1].set_xlabel("Step")
axes[1].set_ylabel("Loss")
else:
axes[1].bar(
["solver reward pre", "solver reward post"],
[
pre_solver_metrics["mean_reward"],
post_solver_metrics["mean_reward"],
],
color=["#9bbb59", "#9bbb59"],
)
axes[1].set_title("Solver Mean Reward")
figure.tight_layout()
figure.savefig(plot_path)
plt.close(figure)
return plot_path
def run_workflow(dry_run: bool = False) -> dict[str, object]:
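    """Run the full pipeline: build the dataset, evaluate pre-training,
    train with GRPO, re-evaluate, and save the plot and metrics artifacts.
    """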
dataset, bug_bank = create_dataset()
print(
f"Built dataset with {len(dataset)} rows from "
f"{len(bug_bank.train_samples)} training bugs and {len(bug_bank.eval_samples)} eval bugs."
)
model, tokenizer = load_training_model_and_tokenizer(dry_run, dataset, bug_bank)
trainer = create_trainer(model, tokenizer, dataset, dry_run)
reset_reward_history()
pre_solver_metrics = evaluate_solver_fixed_set(model, tokenizer, bug_bank)
pre_proposer_metrics = evaluate_proposer_fixed_set(model, tokenizer)
print("Pre-training solver metrics:", pre_solver_metrics)
print("Pre-training proposer metrics:", pre_proposer_metrics)
reset_reward_history()
train_result = trainer.train()
post_solver_metrics = evaluate_solver_fixed_set(trainer.model, tokenizer, bug_bank)
post_proposer_metrics = evaluate_proposer_fixed_set(trainer.model, tokenizer)
plot_path = save_results_plot(
pre_solver_metrics,
post_solver_metrics,
pre_proposer_metrics,
post_proposer_metrics,
trainer.state.log_history,
)
metrics_artifact_path = save_metrics_artifact(pre_proposer_metrics, post_proposer_metrics)
results = {
"train_result": train_result,
"pre_solver_metrics": pre_solver_metrics,
"post_solver_metrics": post_solver_metrics,
"pre_proposer_metrics": pre_proposer_metrics,
"post_proposer_metrics": post_proposer_metrics,
"plot_path": str(plot_path) if plot_path else None,
"metrics_artifact_path": str(metrics_artifact_path),
"dataset_size": len(dataset),
"train_bug_count": len(bug_bank.train_samples),
"eval_bug_count": len(bug_bank.eval_samples),
}
print("Post-training solver metrics:", post_solver_metrics)
print("Post-training proposer metrics:", post_proposer_metrics)
if plot_path:
print(f"Saved plot to {plot_path}")
print(f"Saved proposer metrics to {metrics_artifact_path}")
return results
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--dry_run", action="store_true", help="Run a tiny local GRPO smoke test.")
args = parser.parse_args()
run_workflow(dry_run=args.dry_run)
if __name__ == "__main__":
main()