# train_reward.py: train a DistilBERT reward model with LoRA on Kantian critique feedback and push it to the Hugging Face Hub
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset
from huggingface_hub import login, HfApi
import os
from datetime import datetime
from dotenv import load_dotenv
load_dotenv()
FEEDBACK_FILE = "feedback.json"
REWARD_MODEL_PATH = "./reward_model"
HF_TOKEN = os.getenv("HF_TOKEN")
HF_REWARD_REPO = os.getenv("HF_REWARD_REPO", "modular-ai/kantian-reward-model")
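# For reference, the environment variables read above could be supplied via a
# .env file along these lines (placeholder values):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   HF_REWARD_REPO=modular-ai/kantian-reward-model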
# Kantian Persona Context
KANTIAN_CONTEXT = """Kantian Adversarial Critic - Personality: duty-focused, universality-tester, moral-consistency-seeker, rights-defender.
ADVERSARIAL CRITIQUE MODE:
1. Challenge arguments systematically
2. Identify flaws and weaknesses rigorously
3. Quote exact text when critiquing
4. Attack logical fallacies directly
5. Test through adversarial analysis
Evaluates critiques based on:
- Strength of adversarial challenge
- Rigor in identifying weaknesses
- Application of Kantian principles
- Systematic argument testing
"""
if not os.path.exists(FEEDBACK_FILE):
    print("No feedback data.")
    exit()

with open(FEEDBACK_FILE, "r") as f:
    data = json.load(f)

if len(data) < 50:
    print(f"Need 50+ samples. Current: {len(data)}")
    exit()
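# Illustrative example of a single feedback.json record (field values are made
# up; the field names match what the loop below reads):
#   {
#     "prompt": "Argue that convenience can outweigh privacy.",
#     "response": "This argument fails the universality test because ...",
#     "reward": 1,
#     "text_feedback": "Strong critique; it quotes the exact claim it attacks."
#   }
# "reward" is the binary label (0 = not helpful, 1 = helpful) and
# "text_feedback" is optional.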
# Prepare - Format for Kantian critique training
# Include text feedback when available for richer training
print(f"Processing {len(data)} feedback samples...")
texts = []
for d in data:
    prompt = d['prompt']
    response = d['response']
    text_feedback = d.get('text_feedback', '')
    # Create training text that captures Kantian critique quality
    if text_feedback:
        # Include detailed feedback for more nuanced training
        text = f"{KANTIAN_CONTEXT}\n\n{prompt}\n\nKantian Critique: {response}\n\nDetailed Feedback: {text_feedback}"
    else:
        text = f"{KANTIAN_CONTEXT}\n\n{prompt}\n\nKantian Critique: {response}"
    texts.append(text)
labels = [d['reward'] for d in data] # 0 (not helpful) or 1 (helpful)
dataset = Dataset.from_dict({"text": texts, "label": labels})
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
def tokenize_fn(examples):
    # Pad to a fixed length so the default data collator can stack batches
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
tokenized = dataset.map(tokenize_fn, batched=True)
splits = tokenized.train_test_split(test_size=0.2)
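# Log the resulting 80/20 split sizes for visibility
print(f"Train samples: {len(splits['train'])}, eval samples: {len(splits['test'])}")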
# Model + LoRA
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8, lora_alpha=32, lora_dropout=0.1,
    bias="none",
    target_modules=["q_lin", "v_lin"],  # DistilBERT attention projections are named q_lin/k_lin/v_lin/out_lin
)
model = get_peft_model(model, peft_config)
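# PEFT's built-in summary of trainable vs. total parameters; confirms that only
# the LoRA adapter (plus the classification head) will be trained
model.print_trainable_parameters()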
args = TrainingArguments(
    output_dir=REWARD_MODEL_PATH,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    logging_dir="./logs_reward",
    fp16=torch.cuda.is_available(),
    report_to="none",  # disable external experiment trackers
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=splits["train"],
    eval_dataset=splits["test"],
)
print("Training reward model...")
trainer.train()
trainer.save_model(REWARD_MODEL_PATH)
print(f"Reward model saved to {REWARD_MODEL_PATH}")
# Push to Hugging Face with version tag
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        api = HfApi()
        # Create version tag based on timestamp and sample count
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        version_tag = f"v-{len(data)}-samples-{timestamp}"
        print(f"Pushing reward model to Hugging Face as version: {version_tag}")
        # Push the adapter and tokenizer, then tag the resulting revision so
        # older versions stay addressable on the Hub
        model.push_to_hub(
            HF_REWARD_REPO,
            commit_message=f"Reward model trained on {len(data)} samples - {timestamp}",
        )
        tokenizer.push_to_hub(HF_REWARD_REPO)
        api.create_tag(HF_REWARD_REPO, tag=version_tag)
        print(f"✓ Reward model pushed to {HF_REWARD_REPO} with tag: {version_tag}")
        print("  Old versions remain accessible on Hugging Face")
    except Exception as e:
        print(f"Warning: Could not push to Hugging Face: {e}")
else:
    print("Warning: HF_TOKEN not set, skipping model upload")