# Source: CivicAI / scripts / train_ppo.py
# (Hugging Face upload header removed — author: mahammadaftab, commit 6298125, "Final updated")
"""
CivicAI TRL PPO Training Script β€” scripts/train_ppo.py
=======================================================
Full training pipeline using HuggingFace TRL.
LLM (GPT-2) receives society state as text β†’ outputs JSON action.
PPO optimises the LLM against the CivicAI environment reward.
"""
from __future__ import annotations
import os, sys, json, random
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from civicai.environment import CivicAIEnv
from civicai.models import Action, SubsidyPolicy
from civicai.reward import get_named_scores, compute_reward
# ── Config ────────────────────────────────────────────────────────────────────
MODEL_NAME = "gpt2"            # swap for "meta-llama/Llama-3.2-1B" on Colab A100
TASK_ID = "stabilize_economy"  # environment task identifier
N_EPISODES = 20                # episodes to train
STEPS_EP = 50                  # max steps per episode
BATCH_SIZE = 1
LR = 1.41e-5
SEED = 42

# Seed every RNG source this script touches so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dark-theme palette shared by the matplotlib figures below.
DARK, PANEL, GRID = "#0f172a", "#1e293b", "#334155"
# ── Prompt / Parser ───────────────────────────────────────────────────────────
def obs_to_prompt(obs: dict) -> str:
    """Format an observation dict as a one-line instruction prompt for the LLM.

    The prompt states the current society metrics, spells out the JSON action
    schema the model must emit, and ends with "Action:" as the generation cue.
    """
    state = (
        f"You are a policy advisor. State: Turn={obs['turn']}, "
        f"GDP=${obs['gdp']:.0f}B, Inflation={obs['inflation']:.1%}, "
        f"Employment={obs['employment_rate']:.1%}, "
        f"Satisfaction={obs['public_satisfaction']:.1%}, "
        f"Health={obs['health_index']:.1%}, Crime={obs['crime_rate']:.1%}. "
    )
    schema = (
        '{"tax_rate":0.0-1.0,"healthcare_budget":0.0-1.0,'
        '"education_budget":0.0-1.0,"police_budget":0.0-1.0,'
        '"subsidy_policy":"none|agriculture|industry|technology"}'
    )
    return state + "Output JSON: " + schema + " Action:"
def parse_action(text: str) -> Action:
    """Extract the first {...}-delimited JSON object from *text* into an Action.

    Budget fields are clamped into [0, 1] and missing keys fall back to mild
    defaults. If no valid JSON object can be recovered, a random (but
    plausible) Action is returned so the environment can always advance.
    """
    def _clamp01(v) -> float:
        return max(0.0, min(1.0, float(v)))

    try:
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end != -1:
            payload = json.loads(text[start:end + 1])
            return Action(
                tax_rate=_clamp01(payload.get("tax_rate", 0.25)),
                healthcare_budget=_clamp01(payload.get("healthcare_budget", 0.20)),
                education_budget=_clamp01(payload.get("education_budget", 0.15)),
                police_budget=_clamp01(payload.get("police_budget", 0.10)),
                subsidy_policy=SubsidyPolicy(payload.get("subsidy_policy", "none")),
            )
    except Exception:
        pass  # malformed model output — fall through to the random fallback
    return Action(
        tax_rate=random.uniform(0.2, 0.4),
        healthcare_budget=random.uniform(0.1, 0.3),
        education_budget=random.uniform(0.05, 0.2),
        police_budget=random.uniform(0.05, 0.15),
    )
# ── Random Baseline ───────────────────────────────────────────────────────────
def run_random_baseline(n: int = 5) -> float:
    """Return the mean per-step reward of a uniformly random policy.

    Runs *n* episodes (seeds 0..n-1) of at most STEPS_EP steps each, sampling
    actions from fixed uniform ranges, and averages the per-episode mean
    step rewards.
    """
    env = CivicAIEnv()
    episode_means: list[float] = []
    for seed in range(n):
        # Dedicated per-episode RNG so the baseline is reproducible.
        rng = random.Random(seed)
        env.reset(task_id=TASK_ID, seed=seed)
        step_rewards = []
        for _ in range(STEPS_EP):
            action = Action(
                tax_rate=rng.uniform(0.15, 0.5),
                healthcare_budget=rng.uniform(0.08, 0.35),
                education_budget=rng.uniform(0.05, 0.25),
                police_budget=rng.uniform(0.03, 0.18),
            )
            _, reward, done, _ = env.step(action)
            step_rewards.append(reward)
            if done:
                break
        episode_means.append(float(np.mean(step_rewards)))
    return float(np.mean(episode_means))
# ── Main Training ─────────────────────────────────────────────────────────────
def train_ppo():
    """Run the full PPO fine-tuning loop against the CivicAI environment.

    Pipeline: build policy + frozen reference models, compute a random-action
    baseline, then for N_EPISODES episodes prompt the LLM with the society
    state, parse its JSON action, step the environment, and feed the scalar
    reward back through ``PPOTrainer.step``. Saves the trained model, a JSON
    results file, and two plots under ``assets/``.

    Returns:
        dict with baseline average, per-episode rewards/components, the
        final-5-episode average, and improvement over baseline.

    NOTE(review): this uses the positional ``PPOTrainer(config, model,
    ref_model, tokenizer)`` constructor and ``PPOConfig(model_name=...)`` —
    the pre-0.8 TRL API. Verify the pinned ``trl`` version matches.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[CivicAI] TRL PPO Training | model={MODEL_NAME} device={device}")
    # Models: trainable policy (with value head) plus a frozen reference copy
    # used by TRL for the KL penalty.
    config = PPOConfig(
        model_name=MODEL_NAME,
        learning_rate=LR,
        batch_size=BATCH_SIZE,
        mini_batch_size=1,
        gradient_accumulation_steps=1,
        log_with=None,  # no external experiment tracker
    )
    model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(device)
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    ppo = PPOTrainer(config, model, ref_model, tokenizer)
    gen_kwargs = dict(
        max_new_tokens=80, do_sample=True, top_k=50, top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    env = CivicAIEnv()
    # Baseline: average reward of random actions, for the training plot.
    print("[CivicAI] Computing random baseline...")
    baseline_avg = run_random_baseline(5)
    print(f" Random baseline avg reward: {baseline_avg:.4f}")
    # Training loop: one PPO update per environment step (batch size 1).
    episode_rewards, episode_components = [], []
    print(f"[CivicAI] Training for {N_EPISODES} episodes...")
    for ep in range(N_EPISODES):
        obs = env.reset(task_id=TASK_ID, seed=ep)  # new seed each episode
        ep_rewards, ep_comp = [], []
        for step in tqdm(range(STEPS_EP), desc=f"Ep {ep+1}/{N_EPISODES}", leave=False):
            prompt = obs_to_prompt(obs.model_dump())
            query = tokenizer.encode(prompt, return_tensors="pt").to(device)[0]
            response = ppo.generate(query.unsqueeze(0), **gen_kwargs)
            # Strip the prompt tokens so only the generated continuation remains.
            response_ids = response[0][len(query):]
            text = tokenizer.decode(response_ids, skip_special_tokens=True)
            action = parse_action(text)
            obs, reward, done, info = env.step(action)
            # Named component scores (economic/health/satisfaction/crime) for plotting.
            state = env.state()
            robj = compute_reward(state, action)
            ep_comp.append(get_named_scores(robj))
            reward_t = torch.tensor([reward], dtype=torch.float).to(device)
            # PPO update on this single (query, response, reward) triple.
            ppo.step([query], [response_ids], [reward_t])
            ep_rewards.append(reward)
            if done:
                break
        avg_r = float(np.mean(ep_rewards))
        episode_rewards.append(avg_r)
        # Per-episode mean of each named component score.
        episode_components.append({
            k: round(float(np.mean([c[k] for c in ep_comp])), 4)
            for k in ep_comp[0]
        })
        print(f" Ep {ep+1:2d}: avg_reward={avg_r:.4f} "
              + " ".join(f"{k}={v:.3f}" for k, v in episode_components[-1].items()))
    # ── Save model ────────────────────────────────────────────────────────────
    os.makedirs("assets", exist_ok=True)
    model.save_pretrained("assets/civicai_ppo_model")
    tokenizer.save_pretrained("assets/civicai_ppo_model")
    print("\n Model saved to assets/civicai_ppo_model/")
    # ── Save JSON results ─────────────────────────────────────────────────────
    results = {
        "baseline_avg": baseline_avg,
        "episode_rewards": episode_rewards,
        "episode_components": episode_components,
        # Final performance is judged on the last 5 episodes.
        "final_avg": float(np.mean(episode_rewards[-5:])),
        "improvement": float(np.mean(episode_rewards[-5:])) - baseline_avg,
    }
    with open("assets/training_results.json", "w") as f:
        json.dump(results, f, indent=2)
    # ── Plots ─────────────────────────────────────────────────────────────────
    _plot_training_curve(episode_rewards, baseline_avg)
    _plot_component_breakdown(episode_components)
    print("\n[CivicAI] Training complete.")
    print(f" Baseline avg: {baseline_avg:.4f}")
    print(f" Final 5-ep avg: {results['final_avg']:.4f}")
    print(f" Improvement: {results['improvement']:+.4f}")
    return results
def _plot_training_curve(rewards: list[float], baseline: float) -> None:
    """Save a dark-themed training curve to ``assets/reward_curve.png``.

    Plots raw per-episode rewards faintly, a 3-episode moving average on top,
    the random baseline as a dashed line, and shades where the smoothed curve
    beats the baseline.

    Fixes vs. the original: the ``mode="valid"`` moving average was plotted
    starting at x=0 even though its values are centered on episodes
    1..N-2, shifting the smoothed line (and the fill region) left by one
    episode; there was also no guard for fewer than 3 episodes, which
    produced an empty smoothed series.
    """
    window = 3
    if len(rewards) >= window:
        smooth = np.convolve(rewards, np.ones(window) / window, mode="valid")
        # "valid" drops (window-1)//2 points per edge: value i is centered
        # on episode i + window//2, so start the x axis there.
        xs = list(range(window // 2, window // 2 + len(smooth)))
    else:
        # Too few episodes to smooth — fall back to the raw series.
        smooth = np.asarray(rewards, dtype=float)
        xs = list(range(len(smooth)))
    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(DARK); ax.set_facecolor(PANEL)
    ax.plot(rewards, color="#06b6d4", alpha=0.4, linewidth=1)
    ax.plot(xs, smooth, color="#06b6d4", linewidth=2.5,
            label=f"PPO Agent (final={rewards[-1]:.3f})")
    ax.axhline(baseline, color="#ef4444", linestyle="--", linewidth=1.8,
               label=f"Random Baseline ({baseline:.3f})")
    ax.fill_between(xs, smooth, baseline,
                    where=[s > baseline for s in smooth],
                    alpha=0.15, color="#06b6d4", label="Improvement over baseline")
    ax.set_ylim(0, 1.05)
    ax.set_xlabel("Episode", color="#94a3b8"); ax.set_ylabel("Avg Step Reward", color="#94a3b8")
    ax.set_title("CivicAI TRL PPO — Training Curve", color="#e2e8f0", fontsize=14, fontweight="bold")
    ax.tick_params(colors="#94a3b8")
    for sp in ax.spines.values(): sp.set_edgecolor(GRID)
    ax.grid(axis="y", color=GRID, linewidth=0.5, linestyle="--")
    ax.legend(facecolor=PANEL, edgecolor=GRID, labelcolor="#e2e8f0")
    plt.tight_layout()
    plt.savefig("assets/reward_curve.png", dpi=150, facecolor=DARK)
    plt.close()
    print(" Saved: assets/reward_curve.png")
def _plot_component_breakdown(components: list[dict]) -> None:
    """Save a 1x4 panel of named reward components to ``assets/component_scores.png``.

    One subplot per component score (economic, health, satisfaction, crime),
    each showing its per-episode trajectory over training, styled to match
    the dark dashboard theme.
    """
    panels = [
        ("economic_score", "#f59e0b"),
        ("health_score", "#10b981"),
        ("satisfaction_score", "#a78bfa"),
        ("crime_score", "#f97316"),
    ]
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    fig.patch.set_facecolor(DARK)
    fig.suptitle("Named Reward Components Over Training", color="#e2e8f0",
                 fontsize=13, fontweight="bold")
    for axis, (key, color) in zip(axes, panels):
        series = [entry[key] for entry in components]
        axis.set_facecolor(PANEL)
        axis.plot(series, color=color, linewidth=2)
        axis.fill_between(range(len(series)), series, alpha=0.15, color=color)
        axis.set_ylim(0, 1.05)
        axis.set_title(key.replace("_score", "").capitalize(), color="#e2e8f0", fontsize=11)
        axis.tick_params(colors="#94a3b8", labelsize=8)
        for spine in axis.spines.values():
            spine.set_edgecolor(GRID)
        axis.grid(color=GRID, linewidth=0.4, linestyle="--")
    plt.tight_layout()
    plt.savefig("assets/component_scores.png", dpi=150, facecolor=DARK)
    plt.close()
    print(" Saved: assets/component_scores.png")
if __name__ == "__main__":
    # Script entry point: run the full PPO training pipeline.
    train_ppo()