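"""GRPO training entry point for the DebugZero proposer/solver loop.

Builds a mixed-role dataset from the bug bank (solver rows that repair injected
bugs, proposer rows that inject new ones), fine-tunes the model with TRL's
GRPOTrainer, evaluates both roles on fixed sets before and after training, and
saves a results plot plus a proposer-metrics artifact.

Run with --dry_run for a tiny CPU-only smoke test.
"""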
from __future__ import annotations

import importlib.util
import json
import math
import os
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from datasets import Dataset

try:
    from unsloth import FastLanguageModel, is_bfloat16_supported

    HAS_UNSLOTH = True
except ImportError:
    HAS_UNSLOTH = False

    def is_bfloat16_supported() -> bool:
        return False

try:
    from unsloth import PatchFastRL

    PatchFastRL("GRPO", FastLanguageModel)
except ImportError:
    pass

try:
    from bug_bank import BugBank, BugSample, build_bug_bank
    from seed_bank import SEED_BANK, SeedSpec, get_seed_by_id
    from server.bug_injector import infer_bug_operator
    from server.executor import execute_code
    from server.graders import (
        compute_ast_distance,
        compute_proposer_reward,
        compute_solver_reward,
        is_effectively_unchanged,
        reset_reward_history,
    )
    from training.dual_role_sampler import sample_proposer_prompt, sample_solver_prompt
except ImportError:
    from ..bug_bank import BugBank, BugSample, build_bug_bank
    from ..seed_bank import SEED_BANK, SeedSpec, get_seed_by_id
    from ..server.bug_injector import infer_bug_operator
    from ..server.executor import execute_code
    from ..server.graders import (
        compute_ast_distance,
        compute_proposer_reward,
        compute_solver_reward,
        is_effectively_unchanged,
        reset_reward_history,
    )
    from .dual_role_sampler import sample_proposer_prompt, sample_solver_prompt

DEFAULT_MODEL_ID = "unsloth/Qwen2.5-Coder-3B-Instruct"
DEFAULT_FALLBACK_MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
DEFAULT_OUTPUT_DIR = Path("debugzero_model")
DEFAULT_SOLVER_WEIGHT = 2
DEFAULT_NUM_GENERATIONS = 4
DEFAULT_MAX_STEPS = 80
DEFAULT_MAX_PROMPT_LENGTH = 768
DEFAULT_MAX_COMPLETION_LENGTH = 256
DRY_RUN_MAX_STEPS = 2
DEFAULT_PROPOSER_METRICS_PATH = DEFAULT_OUTPUT_DIR / "proposer_metrics.json"
TARGETED_PROMPT_RATIO = 0.75
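# Helpers for turning chat-style prompts/completions back into plain text and
# for extracting the first fenced Python block from a model completion.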
def extract_python_code(text: str) -> str:
    match = re.search(r"```(?:python)?\s(.*?)```", text, flags=re.DOTALL)
    if match:
        return match.group(1).strip()
    return text.strip()


def completion_to_text(completion) -> str:
    if isinstance(completion, list) and completion:
        item = completion[0]
        if isinstance(item, dict):
            return item.get("content", "")
        return str(item)
    return str(completion)


def prompt_to_text(prompt) -> str:
    if isinstance(prompt, list):
        parts = []
        for item in prompt:
            if isinstance(item, dict):
                parts.append(str(item.get("content", "")))
            else:
                parts.append(str(item))
        return "\n".join(part for part in parts if part)
    if isinstance(prompt, dict):
        return str(prompt.get("content", ""))
    return str(prompt)
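# Run a candidate program against its seed's test via server.executor.execute_code
# and summarize the outcome (pass/fail, syntax error, unsafe import, truncated output).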
def execute_candidate(seed: SeedSpec, candidate_code: str) -> dict[str, object]:
    result = execute_code(candidate_code, seed.test)
    execution_result = result.output[:500] if result.output else ""
    unsafe_code = execution_result.startswith("Unsafe import detected.")
    return {
        "tests_passed": result.passed,
        "syntax_error": result.syntax_error,
        "unsafe_code": unsafe_code,
        "execution_result": execution_result,
    }
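# Dataset construction: one solver row per training bug, plus roughly one
# proposer row for every `solver_weight` solver rows, weighted toward seeds the
# proposer struggled to break in a previous run.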
def build_mixed_role_dataset(
    bug_bank: BugBank,
    solver_weight: int = DEFAULT_SOLVER_WEIGHT,
) -> Dataset:
    rows: list[dict[str, object]] = []
    for bug_sample in bug_bank.train_samples:
        prompt_text = sample_solver_prompt(
            bug_sample.buggy_code,
            bug_sample.execution_result,
            mode="concise",
        )
        rows.append(
            {
                "prompt": [{"role": "user", "content": prompt_text}],
                "role": "solver",
                "seed_id": bug_sample.seed_id,
                "original_code": bug_sample.original_code,
                "buggy_code": bug_sample.buggy_code,
                "bug_operator": bug_sample.bug_operator,
                "execution_result": bug_sample.execution_result,
            }
        )
    target_proposer_rows = max(1, math.ceil(len(rows) / solver_weight)) if rows else len(SEED_BANK)
    proposer_rows = build_weighted_proposer_rows(bug_bank, target_proposer_rows)
    rows.extend(proposer_rows)
    return Dataset.from_list(rows)


def create_dataset() -> tuple[Dataset, BugBank]:
    bug_bank = build_bug_bank()
    return build_mixed_role_dataset(bug_bank), bug_bank
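# GRPO reward functions. Both see every row in the batch; each scores only rows
# for its own role and returns 0.0 for the other role, so the trainer can
# combine them without the roles interfering.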
def prop_rew(prompts, completions, **kwargs):
    rewards: list[float] = []
    roles = kwargs.get("role", [])
    seed_ids = kwargs.get("seed_id", [])
    original_codes = kwargs.get("original_code", [])
    for i, completion in enumerate(completions):
        role = roles[i] if i < len(roles) else roles[0]
        if role != "proposer":
            rewards.append(0.0)
            continue
        seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0]
        original_code = original_codes[i] if i < len(original_codes) else original_codes[0]
        seed = get_seed_by_id(seed_id)
        candidate_code = extract_python_code(completion_to_text(completion))
        execution_meta = execute_candidate(seed, candidate_code)
        unchanged_code = is_effectively_unchanged(original_code, candidate_code)
        changed_but_passing = (
            (not unchanged_code)
            and execution_meta["tests_passed"]
            and (not execution_meta["syntax_error"])
        )
        proposer_meta = {
            "seed_id": seed.seed_id,
            "tests_passed": execution_meta["tests_passed"],
            "syntax_error": execution_meta["syntax_error"],
            "unsafe_code": execution_meta["unsafe_code"],
            "unchanged_code": unchanged_code,
            "changed_but_passing": changed_but_passing,
            "plausibility_score": 0.0,
        }
        if not execution_meta["syntax_error"]:
            proposer_meta["plausibility_score"] = compute_ast_distance(original_code, candidate_code)
        rewards.append(compute_proposer_reward(proposer_meta))
    return rewards


def solv_rew(prompts, completions, **kwargs):
    rewards: list[float] = []
    roles = kwargs.get("role", [])
    seed_ids = kwargs.get("seed_id", [])
    for i, completion in enumerate(completions):
        role = roles[i] if i < len(roles) else roles[0]
        if role != "solver":
            rewards.append(0.0)
            continue
        seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0]
        seed = get_seed_by_id(seed_id)
        candidate_code = extract_python_code(completion_to_text(completion))
        execution_meta = execute_candidate(seed, candidate_code)
        solver_meta = {
            "seed_id": seed.seed_id,
            "tests_passed": execution_meta["tests_passed"],
            "syntax_error": execution_meta["syntax_error"],
            "unsafe_code": execution_meta["unsafe_code"],
        }
        rewards.append(compute_solver_reward(solver_meta))
    return rewards
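# Fixed-set evaluation: greedy decoding on the held-out bug bank (solver) and
# on every seed program (proposer), run before and after training.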
def evaluate_bug_sample(candidate_code: str, bug_sample: BugSample) -> dict[str, object]:
    seed = get_seed_by_id(bug_sample.seed_id)
    evaluation = execute_candidate(seed, candidate_code)
    reward = compute_solver_reward(
        {
            "seed_id": bug_sample.seed_id,
            "tests_passed": evaluation["tests_passed"],
            "syntax_error": evaluation["syntax_error"],
            "unsafe_code": evaluation["unsafe_code"],
        }
    )
    return {**evaluation, "reward": reward}


def evaluate_solver_fixed_set(model, tokenizer, bug_bank: BugBank) -> dict[str, float]:
    results = []
    for bug_sample in bug_bank.eval_samples:
        prompt = sample_solver_prompt(
            bug_sample.buggy_code,
            bug_sample.execution_result,
            mode="concise",
        )
        candidate_code = generate_code(model, tokenizer, prompt, do_sample=False)
        results.append(evaluate_bug_sample(candidate_code, bug_sample))
    return summarize_solver_results(results)


def evaluate_proposer_fixed_set(model, tokenizer) -> dict[str, float]:
    results = []
    for seed in SEED_BANK:
        prompt = sample_proposer_prompt(seed.original_code)
        candidate_code = generate_code(model, tokenizer, prompt, do_sample=False)
        evaluation = execute_candidate(seed, candidate_code)
        unchanged_code = is_effectively_unchanged(seed.original_code, candidate_code)
        valid_bug = (not evaluation["tests_passed"]) and (not evaluation["syntax_error"])
        changed_but_passing = (
            (not unchanged_code)
            and evaluation["tests_passed"]
            and (not evaluation["syntax_error"])
        )
        reward = compute_proposer_reward(
            {
                "seed_id": seed.seed_id,
                "tests_passed": evaluation["tests_passed"],
                "syntax_error": evaluation["syntax_error"],
                "unsafe_code": evaluation["unsafe_code"],
                "unchanged_code": unchanged_code,
                "changed_but_passing": changed_but_passing,
                "plausibility_score": 0.0
                if evaluation["syntax_error"]
                else compute_ast_distance(seed.original_code, candidate_code),
            }
        )
        results.append(
            {
                "seed_id": seed.seed_id,
                **evaluation,
                "reward": reward,
                "unchanged_code": unchanged_code,
                "valid_bug": valid_bug,
                "changed_but_passing": changed_but_passing,
                "likely_bug_family": infer_bug_operator(seed.original_code, candidate_code) or "unknown",
            }
        )
    summary = summarize_proposer_results(results)
    summary["by_seed"] = summarize_proposer_by_seed(results)
    summary["by_bug_family"] = summarize_proposer_by_bug_family(results)
    return summary
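# Aggregation helpers for the evaluation results above.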
def summarize_solver_results(results: list[dict[str, object]]) -> dict[str, float]:
    total = len(results) or 1
    passed = sum(1 for result in results if result["tests_passed"])
    syntax_errors = sum(1 for result in results if result["syntax_error"])
    mean_reward = sum(float(result["reward"]) for result in results) / total
    return {
        "pass_rate": passed / total,
        "syntax_error_rate": syntax_errors / total,
        "mean_reward": mean_reward,
    }


def summarize_proposer_results(results: list[dict[str, object]]) -> dict[str, float]:
    total = len(results) or 1
    bug_rate = sum(
        1 for result in results if (not result["tests_passed"]) and (not result["syntax_error"])
    )
    unchanged = sum(1 for result in results if result.get("unchanged_code"))
    changed_but_passing = sum(1 for result in results if result.get("changed_but_passing"))
    syntax_errors = sum(1 for result in results if result["syntax_error"])
    mean_reward = sum(float(result["reward"]) for result in results) / total
    return {
        "break_rate": bug_rate / total,
        "valid_bug_rate": bug_rate / total,
        "unchanged_rate": unchanged / total,
        "changed_but_passing_rate": changed_but_passing / total,
        "syntax_error_rate": syntax_errors / total,
        "mean_reward": mean_reward,
    }


def summarize_proposer_by_seed(results: list[dict[str, object]]) -> dict[str, dict[str, float]]:
    grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
    for result in results:
        grouped[str(result["seed_id"])].append(result)
    summary: dict[str, dict[str, float]] = {}
    for seed_id, seed_results in grouped.items():
        total = len(seed_results) or 1
        summary[seed_id] = {
            "valid_bug_rate": sum(1 for item in seed_results if item.get("valid_bug")) / total,
            "unchanged_rate": sum(1 for item in seed_results if item.get("unchanged_code")) / total,
            "changed_but_passing_rate": sum(
                1 for item in seed_results if item.get("changed_but_passing")
            )
            / total,
            "mean_reward": sum(float(item["reward"]) for item in seed_results) / total,
        }
    return summary


def summarize_proposer_by_bug_family(results: list[dict[str, object]]) -> dict[str, dict[str, float]]:
    grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
    for result in results:
        grouped[str(result.get("likely_bug_family", "unknown"))].append(result)
    summary: dict[str, dict[str, float]] = {}
    for family, family_results in grouped.items():
        total = len(family_results) or 1
        summary[family] = {
            "count": float(total),
            "valid_bug_rate": sum(1 for item in family_results if item.get("valid_bug")) / total,
            "mean_reward": sum(float(item["reward"]) for item in family_results) / total,
        }
    return summary
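# Proposer curriculum: seeds with a low prior break rate (read from a previous
# run's metrics artifact) receive extra proposer rows, and the first
# TARGETED_PROMPT_RATIO share of prompts is focused on under-represented
# bug operators.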
def build_weighted_proposer_rows(bug_bank: BugBank, target_proposer_rows: int) -> list[dict[str, object]]:
    if target_proposer_rows <= 0:
        return []
    prior_seed_rates = load_prior_seed_break_rates()
    operator_counts = Counter(sample.bug_operator for sample in bug_bank.train_samples)
    seed_to_operators: dict[str, list[str]] = defaultdict(list)
    for sample in bug_bank.train_samples:
        seed_to_operators[sample.seed_id].append(sample.bug_operator)
    seed_weights = {}
    for seed in SEED_BANK:
        prior_break_rate = prior_seed_rates.get(seed.seed_id, 0.5)
        seed_weights[seed.seed_id] = max(1, 1 + round((1.0 - prior_break_rate) * 2))
    rows: list[dict[str, object]] = []
    focus_counters = Counter()
    ordered_seeds = sorted(SEED_BANK, key=lambda seed: (-seed_weights[seed.seed_id], seed.seed_id))
    # Keep every seed represented before adding extra weight to weak seeds.
    for seed in SEED_BANK[:target_proposer_rows]:
        bug_focus = choose_proposer_bug_focus(
            seed.seed_id,
            seed_to_operators[seed.seed_id],
            operator_counts,
            focus_counters,
            len(rows),
            target_proposer_rows,
        )
        prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus)
        rows.append(
            {
                "prompt": [{"role": "user", "content": prompt_text}],
                "role": "proposer",
                "seed_id": seed.seed_id,
                "original_code": seed.original_code,
                "bug_focus": bug_focus if bug_focus else "",
            }
        )
    while len(rows) < target_proposer_rows:
        for seed in ordered_seeds:
            # Add at least one row per seed per pass so this fill loop always
            # terminates, even when every seed weight collapses to 1.
            extra_weight = max(1, seed_weights[seed.seed_id] - 1)
            for _ in range(extra_weight):
                if len(rows) >= target_proposer_rows:
                    break
                bug_focus = choose_proposer_bug_focus(
                    seed.seed_id,
                    seed_to_operators[seed.seed_id],
                    operator_counts,
                    focus_counters,
                    len(rows),
                    target_proposer_rows,
                )
                prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus)
                rows.append(
                    {
                        "prompt": [{"role": "user", "content": prompt_text}],
                        "role": "proposer",
                        "seed_id": seed.seed_id,
                        "original_code": seed.original_code,
                        "bug_focus": bug_focus if bug_focus else "",
                    }
                )
            if len(rows) >= target_proposer_rows:
                break
    return rows
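# Pick a bug operator to target, favoring operators that have been targeted
# least often so far; rows beyond TARGETED_PROMPT_RATIO of the total get an
# open-ended prompt (no focus).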
def choose_proposer_bug_focus(
    seed_id: str,
    operators: list[str],
    operator_counts: Counter,
    focus_counters: Counter,
    row_index: int,
    total_rows: int,
) -> str | None:
    unique_operators = sorted(set(operators), key=lambda op: (operator_counts[op], op))
    if not unique_operators:
        return None
    if row_index >= math.ceil(total_rows * TARGETED_PROMPT_RATIO):
        return None
    del seed_id
    chosen = min(unique_operators, key=lambda op: (focus_counters[op], operator_counts[op], op))
    focus_counters[chosen] += 1
    return chosen


def load_prior_seed_break_rates() -> dict[str, float]:
    if not DEFAULT_PROPOSER_METRICS_PATH.exists():
        return {}
    try:
        data = json.loads(DEFAULT_PROPOSER_METRICS_PATH.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return {}
    seed_metrics = data.get("post_proposer_metrics", {}).get("by_seed", {})
    return {
        str(seed_id): float(metrics.get("valid_bug_rate", 0.5))
        for seed_id, metrics in seed_metrics.items()
        if isinstance(metrics, dict)
    }
def save_metrics_artifact(
    pre_proposer_metrics: dict[str, object],
    post_proposer_metrics: dict[str, object],
) -> Path:
    DEFAULT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    artifact = {
        "pre_proposer_metrics": pre_proposer_metrics,
        "post_proposer_metrics": post_proposer_metrics,
    }
    DEFAULT_PROPOSER_METRICS_PATH.write_text(
        json.dumps(artifact, indent=2, sort_keys=True),
        encoding="utf-8",
    )
    return DEFAULT_PROPOSER_METRICS_PATH
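# Single-prompt generation used by the fixed-set evaluations: applies the chat
# template, generates greedily or with sampling, and decodes only new tokens.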
def generate_code(
    model,
    tokenizer,
    prompt: str | list[dict[str, str]],
    *,
    do_sample: bool,
    max_new_tokens: int = DEFAULT_MAX_COMPLETION_LENGTH,
) -> str:
    import torch

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    if isinstance(prompt, list):
        prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    else:
        prompt_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True
        )
    encoded = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=DEFAULT_MAX_PROMPT_LENGTH)
    model_device = next(model.parameters()).device
    encoded = {key: value.to(model_device) for key, value in encoded.items()}
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        generation_kwargs["temperature"] = 0.7
        generation_kwargs["top_p"] = 0.95
    with torch.no_grad():
        output = model.generate(**encoded, **generation_kwargs)
    # Decode only the newly generated tokens so the prompt text is never echoed
    # back into the extracted completion.
    prompt_length = encoded["input_ids"].shape[-1]
    completion = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return extract_python_code(completion)
def get_training_profile(dry_run: bool) -> dict[str, int | float | bool | str]:
    has_bitsandbytes = importlib.util.find_spec("bitsandbytes") is not None
    return {
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-5,
        "max_steps": DRY_RUN_MAX_STEPS if dry_run else DEFAULT_MAX_STEPS,
        "num_generations": 2 if dry_run else DEFAULT_NUM_GENERATIONS,
        "max_completion_length": DEFAULT_MAX_COMPLETION_LENGTH,
        "report_to": "none",
        "optim": "adamw_torch" if dry_run or not has_bitsandbytes else "adamw_8bit",
    }
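# Model loading: 4-bit Unsloth + LoRA when available, plain Transformers + PEFT
# LoRA otherwise, and a tiny word-level GPT-2 for --dry_run smoke tests.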
def load_training_model_and_tokenizer(
    dry_run: bool,
    dataset: Dataset,
    bug_bank: BugBank,
):
    if dry_run:
        return build_tiny_local_model_and_tokenizer(dataset, bug_bank)
    if HAS_UNSLOTH:
        print("Initializing Unsloth FastLanguageModel...")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=DEFAULT_MODEL_ID,
            max_seq_length=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH,
            load_in_4bit=True,
            fast_inference=True,
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )
        return model, tokenizer
    # Unsloth failed to import (e.g., due to a Kaggle/Colab CUDA mismatch), so
    # fall back to standard Hugging Face PEFT (LoRA).
    print("Unsloth not available. Falling back to standard Transformers loading.")
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_FALLBACK_MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        DEFAULT_FALLBACK_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
    )
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    return model, tokenizer
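# Tiny CPU-only model for --dry_run: a word-level tokenizer trained on the
# dataset itself and a 2-layer GPT-2, just enough to exercise the GRPO plumbing.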
def build_tiny_local_model_and_tokenizer(dataset: Dataset, bug_bank: BugBank):
    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import WordLevelTrainer
    from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast

    corpus = [prompt_to_text(row["prompt"]) for row in dataset]
    corpus.extend(sample.original_code for sample in bug_bank.train_samples)
    corpus.extend(sample.buggy_code for sample in bug_bank.train_samples)
    corpus.extend(sample.original_code for sample in bug_bank.eval_samples)
    corpus.extend(sample.buggy_code for sample in bug_bank.eval_samples)
    corpus.extend(seed.test for seed in SEED_BANK)
    tokenizer_object = Tokenizer(WordLevel(unk_token="<unk>"))
    tokenizer_object.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=["<pad>", "<bos>", "<eos>", "<unk>"],
        min_frequency=1,
    )
    tokenizer_object.train_from_iterator(corpus, trainer=trainer)
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer_object,
        bos_token="<bos>",
        eos_token="<eos>",
        unk_token="<unk>",
        pad_token="<pad>",
    )
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}assistant: {% endif %}"
    )
    model_config = GPT2Config(
        vocab_size=tokenizer.vocab_size,
        n_positions=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH,
        n_ctx=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH,
        n_embd=128,
        n_layer=2,
        n_head=2,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    model = GPT2LMHeadModel(model_config)
    return model, tokenizer
def get_trl_classes():
    if os.name == "nt" and not sys.flags.utf8_mode:
        print("Windows detected. Use `python -X utf8` when running this file locally.")
    from trl import GRPOConfig, GRPOTrainer

    return GRPOConfig, GRPOTrainer
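# Build the GRPOTrainer, keeping only the config kwargs this TRL version's
# GRPOConfig actually accepts so the script works across TRL releases.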
def create_trainer(model, tokenizer, dataset: Dataset, dry_run: bool):
    import inspect

    GRPOConfig, GRPOTrainer = get_trl_classes()
    profile = get_training_profile(dry_run)
    supported_kwargs = inspect.signature(GRPOConfig.__init__).parameters
    config_kwargs = {
        "output_dir": str(DEFAULT_OUTPUT_DIR),
        "per_device_train_batch_size": profile["per_device_train_batch_size"],
        "gradient_accumulation_steps": profile["gradient_accumulation_steps"],
        "learning_rate": profile["learning_rate"],
        "max_steps": profile["max_steps"],
        "num_generations": profile["num_generations"],
        "max_prompt_length": DEFAULT_MAX_PROMPT_LENGTH,
        "max_completion_length": profile["max_completion_length"],
        "bf16": (not dry_run) and HAS_UNSLOTH and is_bfloat16_supported(),
        "fp16": (not dry_run) and not is_bfloat16_supported(),
        "use_cpu": dry_run,
        "logging_steps": 1 if dry_run else 5,
        "optim": profile["optim"],
        "report_to": profile["report_to"],
        "disable_tqdm": True,
    }
    training_args = GRPOConfig(**{k: v for k, v in config_kwargs.items() if k in supported_kwargs})
    print(f"Starting GRPO training for {training_args.max_steps} episodes (steps)...")
    print("To change the number of episodes, modify 'max_steps' in the training profile.")
    return GRPOTrainer(
        model=model,
        reward_funcs=[prop_rew, solv_rew],
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
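# Two-panel results figure: pre/post pass and break rates on the fixed evals,
# plus the training-loss curve (or pre/post solver reward if no loss was logged).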
def save_results_plot(
    pre_solver_metrics: dict[str, float],
    post_solver_metrics: dict[str, float],
    pre_proposer_metrics: dict[str, float],
    post_proposer_metrics: dict[str, float],
    log_history: list[dict[str, float]],
) -> Path | None:
    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib is not installed, skipping plot generation.")
        return None
    DEFAULT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    plot_path = DEFAULT_OUTPUT_DIR / "debugzero_results.png"
    figure, axes = plt.subplots(1, 2, figsize=(10, 4))
    axes[0].bar(
        ["solver pre", "solver post", "proposer pre", "proposer post"],
        [
            pre_solver_metrics["pass_rate"],
            post_solver_metrics["pass_rate"],
            pre_proposer_metrics["break_rate"],
            post_proposer_metrics["break_rate"],
        ],
        color=["#4f81bd", "#4f81bd", "#c0504d", "#c0504d"],
    )
    axes[0].set_ylim(0.0, 1.0)
    axes[0].set_title("Fixed Eval Rates")
    axes[0].set_ylabel("Rate")
    steps = [entry["step"] for entry in log_history if "step" in entry]
    losses = [entry["loss"] for entry in log_history if "loss" in entry]
    if steps and losses:
        axes[1].plot(steps[: len(losses)], losses, marker="o")
        axes[1].set_title("Training Loss")
        axes[1].set_xlabel("Step")
        axes[1].set_ylabel("Loss")
    else:
        axes[1].bar(
            ["solver reward pre", "solver reward post"],
            [
                pre_solver_metrics["mean_reward"],
                post_solver_metrics["mean_reward"],
            ],
            color=["#9bbb59", "#9bbb59"],
        )
        axes[1].set_title("Solver Mean Reward")
    figure.tight_layout()
    figure.savefig(plot_path)
    plt.close(figure)
    return plot_path
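# End-to-end workflow: build data, load model, evaluate both roles, train with
# GRPO, re-evaluate, then save the plot and the proposer-metrics artifact.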
def run_workflow(dry_run: bool = False) -> dict[str, object]:
    dataset, bug_bank = create_dataset()
    print(
        f"Built dataset with {len(dataset)} rows from "
        f"{len(bug_bank.train_samples)} training bugs and {len(bug_bank.eval_samples)} eval bugs."
    )
    model, tokenizer = load_training_model_and_tokenizer(dry_run, dataset, bug_bank)
    trainer = create_trainer(model, tokenizer, dataset, dry_run)
    reset_reward_history()
    pre_solver_metrics = evaluate_solver_fixed_set(model, tokenizer, bug_bank)
    pre_proposer_metrics = evaluate_proposer_fixed_set(model, tokenizer)
    print("Pre-training solver metrics:", pre_solver_metrics)
    print("Pre-training proposer metrics:", pre_proposer_metrics)
    reset_reward_history()
    train_result = trainer.train()
    post_solver_metrics = evaluate_solver_fixed_set(trainer.model, tokenizer, bug_bank)
    post_proposer_metrics = evaluate_proposer_fixed_set(trainer.model, tokenizer)
    plot_path = save_results_plot(
        pre_solver_metrics,
        post_solver_metrics,
        pre_proposer_metrics,
        post_proposer_metrics,
        trainer.state.log_history,
    )
    metrics_artifact_path = save_metrics_artifact(pre_proposer_metrics, post_proposer_metrics)
    results = {
        "train_result": train_result,
        "pre_solver_metrics": pre_solver_metrics,
        "post_solver_metrics": post_solver_metrics,
        "pre_proposer_metrics": pre_proposer_metrics,
        "post_proposer_metrics": post_proposer_metrics,
        "plot_path": str(plot_path) if plot_path else None,
        "metrics_artifact_path": str(metrics_artifact_path),
        "dataset_size": len(dataset),
        "train_bug_count": len(bug_bank.train_samples),
        "eval_bug_count": len(bug_bank.eval_samples),
    }
    print("Post-training solver metrics:", post_solver_metrics)
    print("Post-training proposer metrics:", post_proposer_metrics)
    if plot_path:
        print(f"Saved plot to {plot_path}")
    print(f"Saved proposer metrics to {metrics_artifact_path}")
    return results


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dry_run", action="store_true", help="Run a tiny local GRPO smoke test.")
    args = parser.parse_args()
    run_workflow(dry_run=args.dry_run)


if __name__ == "__main__":
    main()