rlm-arithmetic-training / train_arithmetic_v2.py
mindchain's picture
Upload train_arithmetic_v2.py with huggingface_hub
74c1152 verified
#!/usr/bin/env python3
"""
GRPO + RLVR Training for Simple Arithmetic - v2
Task: 2-digit addition and subtraction
Base Model: Qwen/Qwen3-0.6B-Base
Improvements:
- Better reward function with debugging
- Force EOS token in generation
- Per-step evaluation
- Clear tracking metrics
"""
import os
import re
import random
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
# ============================================================================
# CONFIG
# ============================================================================
# Hugging Face Hub identifiers: starting checkpoint and destination repo.
BASE_MODEL: str = "Qwen/Qwen3-0.6B-Base"
OUTPUT_MODEL: str = "mindchain/qwen3-0.6b-arithmetic-v2"

# Run sizes.
MAX_STEPS: int = 20  # max_steps handed to the GRPO trainer
NUM_SAMPLES: int = 500  # number of generated training problems
EVAL_SAMPLES: int = 20  # problems drawn per evaluation pass
EVAL_EVERY: int = 5  # Evaluate every N steps
# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples, rng=None):
    """Generate random two-digit addition and subtraction problems.

    Addition draws both operands from [10, 99]. Subtraction draws the
    minuend ``a`` from [20, 99] and the subtrahend ``b`` from [10, a - 1],
    so every difference is strictly positive.

    Args:
        n_samples: Number of problems to generate (``0`` yields ``[]``).
        rng: Optional ``random.Random`` instance. Pass a seeded instance for a
            reproducible problem set that does not disturb the global RNG.
            Defaults to the module-level ``random`` functions, which makes the
            same calls in the same order as before this parameter existed.

    Returns:
        A list of dicts with keys ``'prompt'`` (text fed to the model),
        ``'answer'`` and ``'ground_truth'``. Both of the last two hold the
        correct result as a string; the duplicate key lets a reward function
        look the reference up under either name.
    """
    rng = random if rng is None else rng
    samples = []
    for _ in range(n_samples):
        op = rng.choice(['+', '-'])
        if op == '+':
            a = rng.randint(10, 99)
            b = rng.randint(10, 99)
            answer = a + b
        else:
            a = rng.randint(20, 99)
            # b < a keeps the result positive.
            b = rng.randint(10, a - 1)
            answer = a - b
        problem = f"{a} {op} {b} = ?"
        samples.append({
            'prompt': f"Solve: {problem}\nAnswer:",
            'answer': str(answer),
            'ground_truth': str(answer),  # Also provide ground_truth for GRPO
        })
    return samples
# ============================================================================
# REWARD FUNCTION (with debugging)
# ============================================================================
def _extract_last_number(text):
    """Return the last integer/decimal literal in ``text``, or ``""`` if none."""
    # ``(?:\.\d+)?`` only consumes a dot that is followed by digits, so a
    # sentence ending such as "57." yields "57" rather than the unmatchable "57.".
    numbers = re.findall(r'-?\d+(?:\.\d+)?', text)
    return numbers[-1] if numbers else ""


def _numbers_match(predicted, truth):
    """Return True if ``predicted`` equals ``truth`` as an exact decimal value."""
    from decimal import Decimal, InvalidOperation

    try:
        # Exact comparison: "57.0" matches "57" without float rounding issues.
        return Decimal(predicted) == Decimal(truth)
    except InvalidOperation:
        # At least one side is not a number (e.g. nothing was extracted):
        # fall back to plain string equality.
        return predicted == truth


def reward_func(completions, prompts=None, **kwargs):
    """
    Reward function for arithmetic with debugging.

    Scores each completion 1.0 if the last number it contains equals the
    reference answer, else 0.0.

    Args:
        completions: Model outputs. Each is either a plain string or (the
            conversational format) a list of messages. The ``'content'``
            fields of dict messages are joined with spaces; other items are
            converted with ``str``.
        prompts: Accepted for trainer-interface compatibility; unused.
        **kwargs: Extra columns forwarded by the trainer. The reference
            answers are read from the first key among ``'answer'``,
            ``'ground_truth'``, ``'solution'``, ``'label'`` whose value is not
            ``None``.

    Returns:
        One float per (completion, reference) pair. Pairing uses ``zip``, so
        the list is as long as the shorter of the two sequences. If no
        reference column is found, a warning is printed and every completion
        scores ``0.0``.
    """
    answers = None
    for key in ('answer', 'ground_truth', 'solution', 'label'):
        if kwargs.get(key) is not None:
            answers = kwargs[key]
            break
    if answers is None:
        print("⚠️ WARNING: No ground truth found in kwargs!")
        print(f" Available keys: {list(kwargs.keys())}")
        return [0.0] * len(completions)

    rewards = []
    debug_samples = min(2, len(completions))  # Debug first 2 samples
    for i, (completion, truth) in enumerate(zip(completions, answers)):
        # Handle list format (conversational); a None content counts as "".
        if isinstance(completion, list):
            text = " ".join(
                str(m.get('content') or '') if isinstance(m, dict) else str(m)
                for m in completion
            )
        else:
            text = str(completion)
        predicted = _extract_last_number(text)
        is_correct = _numbers_match(predicted, str(truth).strip())
        rewards.append(1.0 if is_correct else 0.0)
        if i < debug_samples:
            status = "βœ…" if is_correct else "❌"
            print(f" [{i+1}] {status} Truth={truth} | Pred={predicted} | Text={text[:80]}...")
    return rewards
# ============================================================================
# EVALUATION
# ============================================================================
def evaluate_model(model, tokenizer, n_samples=EVAL_SAMPLES, step=0):
    """Greedy-decode freshly generated arithmetic problems and report accuracy.

    Every call draws a new random problem set via
    ``generate_arithmetic_samples``, so accuracies from different calls are
    measured on different problems. The prediction is the last number in the
    decoded continuation, compared to the reference as an exact decimal value.

    Args:
        model: Causal LM exposing ``generate``, ``eval``/``train`` and the
            ``training`` flag.
        tokenizer: Tokenizer for ``model``. Its EOS id is used both as the pad
            id and as the stop id during generation.
        n_samples: Number of problems to evaluate.
        step: Training step; used only in the printed header.

    Returns:
        Accuracy as a percentage in [0, 100], or ``0.0`` when ``n_samples`` is 0.

    The model's train/eval mode on entry is restored on exit, even if
    generation raises.
    """
    from decimal import Decimal, InvalidOperation

    print(f"\n{'='*70}")
    print(f"πŸ“Š EVALUATION @ Step {step}")
    print(f"{'='*70}")
    test_samples = generate_arithmetic_samples(n_samples)
    correct = 0
    # Remember the caller's mode instead of forcing train mode afterwards
    # (the pre-training evaluation should not flip the model into training).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, sample in enumerate(test_samples):
                inputs = tokenizer(sample['prompt'], return_tensors='pt')
                if hasattr(model, 'device') and model.device is not None:
                    inputs = {k: v.to(model.device) for k, v in inputs.items()}
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=30,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
                # Decode only the newly generated tokens when the prompt length is known.
                input_ids = inputs.get('input_ids')
                if input_ids is not None and hasattr(input_ids, 'shape'):
                    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
                else:
                    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
                # Take the last number. A trailing sentence period is not part of
                # the number, so "57." is read as "57".
                numbers = re.findall(r'-?\d+(?:\.\d+)?', response)
                predicted = numbers[-1] if numbers else ""
                truth = sample['answer'].strip()
                try:
                    is_correct = Decimal(predicted) == Decimal(truth)
                except InvalidOperation:
                    # Nothing numeric was extracted: fall back to string equality.
                    is_correct = predicted == truth
                if is_correct:
                    correct += 1
                status = "βœ…" if is_correct else "❌"
                print(f"[{i+1}] {status} {truth} | Pred: {predicted} | {response[:40]}...")
    finally:
        model.train(was_training)
    accuracy = correct / n_samples * 100 if n_samples else 0.0
    print(f"\nπŸ“Š Accuracy: {accuracy:.1f}% ({correct}/{n_samples})")
    print(f"{'='*70}\n")
    return accuracy
# ============================================================================
# CALLBACK FOR PER-STEP EVAL
# ============================================================================
from transformers import TrainerCallback
class EvalCallback(TrainerCallback):
    """Run ``evaluate_model`` periodically during training and keep a history.

    Every ``eval_every`` global steps (never at step 0) the wrapped model is
    evaluated. The ``(step, accuracy)`` pair is appended to ``accuracies`` and
    the whole history so far is printed.
    """

    def __init__(self, model, tokenizer, eval_every=EVAL_EVERY):
        self.model = model
        self.tokenizer = tokenizer
        self.eval_every = eval_every
        # Evaluation history as (global_step, accuracy %) tuples, oldest first.
        self.accuracies = []

    def on_step_end(self, args, state, control, **kwargs):
        """Trainer hook run after each step; evaluates on interval steps only."""
        step_now = state.global_step
        # Guard clause: skip step 0 and any step off the evaluation interval.
        if step_now <= 0 or step_now % self.eval_every:
            return
        score = evaluate_model(self.model, self.tokenizer, step=step_now)
        self.accuracies.append((step_now, score))
        print("\nπŸ“ˆ Progress Summary:")
        print("\n".join(f" Step {s}: {a:.1f}%" for s, a in self.accuracies))
        print()
# ============================================================================
# MAIN TRAINING
# ============================================================================
def main():
print("="*70)
print("πŸ”’ GRPO + RLVR Arithmetic Training - v2")
print("="*70)
print(f"Base Model: {BASE_MODEL}")
print(f"Output: {OUTPUT_MODEL}")
print(f"Steps: {MAX_STEPS}")
print(f"Eval every: {EVAL_EVERY} steps")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print("="*70 + "\n")
# Load model and tokenizer
print("πŸ“¦ Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Ensure pad token is set
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print(f" Set pad_token to eos_token: {tokenizer.eos_token}")
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
# Resize embeddings if needed
model.resize_token_embeddings(len(tokenizer))
# Initial evaluation
initial_acc = evaluate_model(model, tokenizer, step=0)
# Generate training data
print("πŸ“Š Generating training data...")
train_samples = generate_arithmetic_samples(NUM_SAMPLES)
train_dataset = Dataset.from_list(train_samples)
print(f"βœ… {len(train_dataset)} training samples\n")
# GRPO Config
is_cpu = not torch.cuda.is_available()
training_args = GRPOConfig(
output_dir="./outputs",
max_steps=MAX_STEPS,
per_device_train_batch_size=2,
num_generations=2,
learning_rate=2e-4,
beta=0.0, # No KL penalty for arithmetic
bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
fp16=False,
gradient_checkpointing=not is_cpu,
optim="adamw_torch" if is_cpu else "adamw_8bit",
logging_steps=1,
save_steps=MAX_STEPS,
push_to_hub=False,
report_to="none",
)
# Eval callback
eval_callback = EvalCallback(model, tokenizer, eval_every=EVAL_EVERY)
print("πŸš€ Starting GRPO Training...")
print(f"Initial accuracy: {initial_acc:.1f}%\n")
# Train
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
reward_funcs=[reward_func],
callbacks=[eval_callback],
)
trainer.train()
# Final evaluation
final_acc = evaluate_model(model, tokenizer, step=MAX_STEPS)
# Summary
print("\n" + "="*70)
print("πŸ“Š FINAL RESULTS")
print("="*70)
print(f"Initial Accuracy: {initial_acc:.1f}%")
print(f"Final Accuracy: {final_acc:.1f}%")
print(f"Improvement: {final_acc - initial_acc:+.1f}%")
print()
print("πŸ“ˆ Training Progress:")
for step, acc in eval_callback.accuracies:
print(f" Step {step}: {acc:.1f}%")
print("="*70)
# Save to Hub
print(f"\nπŸ“¦ Pushing to Hub: {OUTPUT_MODEL}")
trainer.model.push_to_hub(OUTPUT_MODEL)
tokenizer.push_to_hub(OUTPUT_MODEL)
print(f"βœ… Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
if __name__ == "__main__":
main()