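"""Train a LoRA reward model on human feedback for the Kantian adversarial critic.

Reads labeled feedback from feedback.json, fine-tunes a DistilBERT classifier
with LoRA adapters, saves it locally, and optionally pushes a versioned copy
to the Hugging Face Hub.
"""
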
import json
import os
from datetime import datetime

import torch
from datasets import Dataset
from dotenv import load_dotenv
from huggingface_hub import login, HfApi
from peft import LoraConfig, get_peft_model, TaskType
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

load_dotenv()

FEEDBACK_FILE = "feedback.json"
REWARD_MODEL_PATH = "./reward_model"
HF_TOKEN = os.getenv("HF_TOKEN")
HF_REWARD_REPO = os.getenv("HF_REWARD_REPO", "modular-ai/kantian-reward-model")
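
# Rubric prepended to every training example so the reward model scores
# critiques in the same context the critic generates them.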
KANTIAN_CONTEXT = """Kantian Adversarial Critic - Personality: duty-focused, universality-tester, moral-consistency-seeker, rights-defender.

ADVERSARIAL CRITIQUE MODE:
1. Challenge arguments systematically
2. Identify flaws and weaknesses rigorously
3. Quote exact text when critiquing
4. Attack logical fallacies directly
5. Test through adversarial analysis

Evaluates critiques based on:
- Strength of adversarial challenge
- Rigor in identifying weaknesses
- Application of Kantian principles
- Systematic argument testing
"""
if not os.path.exists(FEEDBACK_FILE):
    print("No feedback data.")
    exit()

with open(FEEDBACK_FILE, "r") as f:
    data = json.load(f)

if len(data) < 50:
    print(f"Need 50+ samples. Current: {len(data)}")
    exit()
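
# Build one training text per feedback record: rubric + prompt + critique,
# plus the reviewer's free-form feedback when any was left.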
print(f"Processing {len(data)} feedback samples...") |
|
|
texts = [] |
|
|
for d in data: |
|
|
prompt = d['prompt'] |
|
|
response = d['response'] |
|
|
text_feedback = d.get('text_feedback', '') |
|
|
|
|
|
|
|
|
if text_feedback: |
|
|
|
|
|
text = f"{KANTIAN_CONTEXT}\n\n{prompt}\n\nKantian Critique: {response}\n\nDetailed Feedback: {text_feedback}" |
|
|
else: |
|
|
text = f"{KANTIAN_CONTEXT}\n\n{prompt}\n\nKantian Critique: {response}" |
|
|
texts.append(text) |
|
|
|
|
|
labels = [d['reward'] for d in data] |
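
# Tokenize everything up front; hold out 20% of examples for evaluation.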
dataset = Dataset.from_dict({"text": texts, "label": labels})
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize_fn(examples):
    # Truncate only; padding happens per batch in the data collator below,
    # so mixed-length batches cannot crash the Trainer.
    return tokenizer(examples["text"], truncation=True, max_length=512)


tokenized = dataset.map(tokenize_fn, batched=True)
splits = tokenized.train_test_split(test_size=0.2)
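
# Wrap the DistilBERT classifier with LoRA adapters so only a small
# fraction of the parameters are trained.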
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8, lora_alpha=32, lora_dropout=0.1,
    bias="none",
    # DistilBERT names its attention projections q_lin/v_lin; the BERT-style
    # names "query"/"value" do not match any DistilBERT module.
    target_modules=["q_lin", "v_lin"],
)
model = get_peft_model(model, peft_config)
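
# Training configuration; fp16 is enabled only when a CUDA device is present.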
args = TrainingArguments(
    output_dir=REWARD_MODEL_PATH,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    logging_dir="./logs_reward",
    fp16=torch.cuda.is_available(),
    report_to="none",  # the string "none" disables reporting; report_to=None falls back to "all"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=splits["train"],
    eval_dataset=splits["test"],
    # Pad each batch dynamically to its longest sequence.
    data_collator=DataCollatorWithPadding(tokenizer),
)

print("Training reward model...")
trainer.train()
trainer.save_model(REWARD_MODEL_PATH)
tokenizer.save_pretrained(REWARD_MODEL_PATH)  # keep the tokenizer alongside the local checkpoint
print(f"Reward model saved to {REWARD_MODEL_PATH}")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        api = HfApi()

        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        version_tag = f"v-{len(data)}-samples-{timestamp}"

        print(f"Pushing reward model to Hugging Face as version: {version_tag}")

        model.push_to_hub(
            HF_REWARD_REPO,
            commit_message=f"Reward model trained on {len(data)} samples - {timestamp}",
        )
        tokenizer.push_to_hub(HF_REWARD_REPO)
        # Tag the new commit so this exact version stays addressable later.
        api.create_tag(HF_REWARD_REPO, tag=version_tag)

        print(f"✓ Reward model pushed to {HF_REWARD_REPO} with tag: {version_tag}")
        print(" Old versions remain accessible on Hugging Face")
    except Exception as e:
        print(f"Warning: Could not push to Hugging Face: {e}")
else:
    print("Warning: HF_TOKEN not set, skipping model upload")