# sql_env/scripts/test_training_local.py
# (Uploaded via huggingface_hub by user "hjerpe", commit 9e64e71.)
"""Local test for GRPO training with SQLEnvTRL.
Usage:
docker build -f Dockerfile.test -t sqlenv-test .
docker run sqlenv-test
docker run sqlenv-test python scripts/test_training_local.py \
--config configs/colab_l4.json
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
root = Path(__file__).parent.parent
sys.path.insert(0, str(root))
def load_config(path: str) -> dict:
    """Load a training configuration from a JSON file.

    Args:
        path: Path to a JSON config file (e.g. ``configs/test_cpu.json``).

    Returns:
        The parsed configuration as a dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding: JSON is UTF-8 by spec; the platform default may not be.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def main() -> None:
    """Smoke-test GRPO training end to end with SQLEnvTRL.

    Steps: parse the config path, exercise the environment tools once,
    build a tiny repeated-prompt dataset, run ``GRPOTrainer``, then print
    per-step metrics and a SUCCESS/FAILED verdict based on whether any
    losses and rewards were non-zero.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="configs/test_cpu.json",
        help="Training config JSON",
    )
    args = parser.parse_args()
    cfg = load_config(args.config)
    print(f"Config: {args.config}")
    print(json.dumps(cfg, indent=2))
    # Heavy imports are deferred until after config parsing so `--help`
    # stays fast and the config is already printed if an import fails.
    import transformers
    import trl
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from sql_env.training.trl_adapter import (
        SQLEnvTRL,
        sql_env_reward_func,
    )
    print(f"\nTRL: {trl.__version__}, Transformers: {transformers.__version__}")
    # 1. Configure environment
    # NOTE(review): _configure presumably stores class-level settings shared by
    # every SQLEnvTRL() instance the trainer creates — confirm in trl_adapter.
    SQLEnvTRL._configure(
        questions_path=cfg["questions_path"],
        db_dir=cfg["db_dir"],
        step_budget=cfg["step_budget"],
    )
    env = SQLEnvTRL()
    obs = env.reset()
    print("\n--- Environment smoke test ---")
    print(f"Reset: {obs}")
    r = env.describe(table_name="employee")
    print(f"Describe: {r[:80]}")
    r = env.query(sql="SELECT COUNT(*) FROM employee")
    print(f"Query: {r}")
    r = env.answer(value="10")
    print(f"Answer: {r}")
    print(f"Total reward: {env.reward:.4f}")
    # 2. Dataset
    enable_thinking = cfg.get("enable_thinking", False)
    system_prompt_base = (
        "You answer questions about a SQL database. "
        "Use ONLY the provided tools.\n\n"
        "Strategy:\n"
        "1. Call describe(table_name=...) to see columns\n"
        "2. Call query(sql=...) to run SELECT queries\n"
        "3. Call answer(value=...) to submit your answer"
    )
    # "/no_think" prefix disables thinking mode (Qwen-style chat-template
    # convention — presumably what the target model expects; verify).
    system_prompt = (
        system_prompt_base if enable_thinking else "/no_think\n" + system_prompt_base
    )
    questions = [
        "How many employees are there?",
        "What are the names of all shops?",
        "Find the total number of concerts.",
        "List all singer names.",
    ]
    prompt_msgs = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": q},
        ]
        for q in questions
    ]
    # Tile the four base prompts up to cfg["dataset_size"] rows; the trailing
    # slice trims the overshoot from the ceiling-style repetition count.
    size = cfg.get("dataset_size", len(prompt_msgs))
    repeated = (prompt_msgs * ((size // len(prompt_msgs)) + 1))[:size]
    repeated_q = (questions * ((size // len(questions)) + 1))[:size]
    dataset = Dataset.from_dict({"prompt": repeated, "question_text": repeated_q})
    # 3. Trainer config
    print("\n--- Building trainer ---")
    grpo_kwargs = {
        "output_dir": cfg["output_dir"],
        "per_device_train_batch_size": cfg["per_device_train_batch_size"],
        "num_generations": cfg["num_generations"],
        "num_train_epochs": cfg["num_train_epochs"],
        "max_completion_length": cfg["max_completion_length"],
        "logging_steps": cfg["logging_steps"],
        "log_completions": True,
        "num_completions_to_print": cfg.get("num_completions_to_print", 2),
        # Keep the extra "question_text" column so reward/env code can read it.
        "remove_unused_columns": False,
    }
    # Falsy max_steps (absent, 0, null) falls through to epoch-based training.
    if cfg.get("max_steps"):
        grpo_kwargs["max_steps"] = cfg["max_steps"]
    grpo_kwargs["chat_template_kwargs"] = {
        "enable_thinking": enable_thinking,
    }
    # Precision: at most one of bf16/fp16 is enabled; anything else means fp32.
    precision = cfg.get("precision", "fp32")
    if precision == "bf16":
        grpo_kwargs.update(bf16=True, fp16=False)
    elif precision == "fp16":
        grpo_kwargs.update(bf16=False, fp16=True)
    else:
        grpo_kwargs.update(bf16=False, fp16=False)
    trainer = GRPOTrainer(
        model=cfg["model_name"],
        reward_funcs=sql_env_reward_func,
        train_dataset=dataset,
        environment_factory=SQLEnvTRL,
        args=GRPOConfig(**grpo_kwargs),
    )
    # 4. Train
    print(f"\n--- Training ({cfg.get('max_steps', 'all')} steps) ---")
    trainer.train()
    # 5. Results
    print("\n--- Results ---")
    # Only log entries carrying a "loss" key are training steps; other
    # entries (e.g. the final summary) are skipped.
    for entry in trainer.state.log_history:
        step = entry.get("step")
        loss = entry.get("loss")
        if loss is None:
            continue
        reward = entry.get("reward", 0)
        reward_std = entry.get("reward_std", 0)
        tools_freq = entry.get("tools/call_frequency", 0)
        clipped = entry.get("completions/clipped_ratio", 0)
        mean_len = entry.get("completions/mean_length", 0)
        print(
            f"Step {step:>3}: "
            f"loss={loss:.4f} "
            f"reward={reward:.4f} +/-{reward_std:.4f} "
            f"tools={tools_freq:.2f} "
            f"clipped={clipped:.0%} "
            f"len={mean_len:.0f}"
        )
    losses = [e["loss"] for e in trainer.state.log_history if "loss" in e]
    rewards = [e.get("reward", 0) for e in trainer.state.log_history if "loss" in e]
    print(f"\nLoss: {losses}")
    print(f"Reward: {rewards}")
    # All-zero losses/rewards usually mean a wiring problem (e.g. the reward
    # function never fired), so flag both conditions explicitly.
    if losses and any(v != 0.0 for v in losses):
        print("\nSUCCESS: Non-zero training loss")
    else:
        print("\nFAILED: All losses zero")
    if rewards and any(v != 0.0 for v in rewards):
        print("SUCCESS: Non-zero rewards")
    else:
        print("FAILED: All rewards zero")
if __name__ == "__main__":
main()