# BotX / app.py
# Last change by BotXT: "Update app.py" (commit 4ed4ad6, verified)
# app.py
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling,
)
# ---------------------------
# Step 1: Load Anthropic Dataset
# ---------------------------
# Load the Anthropic HH-RLHF dataset by its Hub id; Step 2 iterates over its
# "train" split.
print("Loading dataset...")
ds = load_dataset(path="Anthropic/hh-rlhf")
# ---------------------------
# Step 2: Prepare prompt-response pairs
# ---------------------------
def _first_exchange(text):
    """Extract the first (human, assistant) exchange from one transcript.

    The prompt is everything before the first "Assistant:" marker, with the
    "Human:" label removed. The reply is the text after that marker, cut at
    the next "Assistant:" marker and also at the next blank-line-prefixed
    "Human:" turn. Follow-up user messages in multi-turn transcripts are
    therefore not treated as part of the assistant's answer.

    Args:
        text: one transcript string (the dataset's "chosen" field).

    Returns:
        A ``(human, assistant)`` tuple of stripped strings, or ``None`` when
        the transcript has no "Assistant:" marker or either side is empty.
    """
    prompt, marker, reply = text.partition("Assistant:")
    if not marker:
        return None
    human = prompt.replace("Human:", "").strip()
    reply = reply.split("Assistant:", 1)[0]
    # Multi-turn transcripts continue with "\n\nHuman: ..."; that text is the
    # user's next message, not part of this reply.
    assistant = reply.split("\n\nHuman:", 1)[0].strip()
    if not human or not assistant:
        return None
    return human, assistant


train_data = []
for item in ds["train"]:
    pair = _first_exchange(item["chosen"])
    if pair is None:
        continue
    human, assistant = pair
    train_data.append({"input": human, "output": assistant})
print(f"Total training examples: {len(train_data)}")
if not train_data:
    raise RuntimeError("No usable prompt/response pairs found in the training split.")
print("Example:", train_data[0])
# ---------------------------
# Step 3: Load tokenizer and model
# ---------------------------
model_name = "distilgpt2"
print(f"Loading model and tokenizer: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2-style tokenizers may come without a padding token, yet Step 4 pads
# every example to a fixed length, so register a dedicated "[PAD]" token.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens(dict(pad_token="[PAD]"))
model = AutoModelForCausalLM.from_pretrained(model_name)
# The vocabulary may have just grown by one token; size the embedding matrix
# to the tokenizer's full vocabulary.
model.resize_token_embeddings(len(tokenizer))
# ---------------------------
# Step 4: Tokenize data
# ---------------------------
def tokenize_function(example):
    """Tokenize one prompt/response pair into a fixed-length training example.

    The text uses the same "You: ... AI: ..." layout that ``chat()`` builds at
    inference time, so the model learns to answer after the "AI:" cue. It ends
    with the tokenizer's EOS token so the model learns where a reply stops.

    Args:
        example: mapping with "input" (user turn) and "output" (assistant reply).

    Returns:
        The tokenizer's encoding, truncated and padded to 128 tokens.
    """
    # Some tokenizers define no EOS token; fall back to no terminator then.
    eos = tokenizer.eos_token or ""
    text = f"You: {example['input']} AI: {example['output']}{eos}"
    return tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
    )


tokenized_data = [tokenize_function(item) for item in train_data]
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of pre-tokenized records.

    Each record is a mapping from field name to a sequence that
    ``torch.tensor`` can convert (presumably the tokenizer's fields such as
    input_ids / attention_mask). Indexing returns a new dict with every field
    converted to a tensor.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        tensors = {}
        for field, values in record.items():
            tensors[field] = torch.tensor(values)
        return tensors
train_dataset = CustomDataset(tokenized_data)
# ---------------------------
# Step 5: Fine-tune model
# ---------------------------
# mlm=False selects causal (next-token) language modeling rather than
# masked-LM, per the transformers collator docs.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
    output_dir="./hh_rlhf_model",
    report_to="none",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    # Mixed precision only when a CUDA device is present.
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    save_steps=500,
    save_total_limit=1,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
print("Starting fine-tuning...")
trainer.train()
print("✅ Training complete!")
# ---------------------------
# Step 6: Simple Chat Loop
# ---------------------------
conversation_history = []


def chat(user_input, max_new_tokens=100):
    """Generate the assistant's reply to ``user_input`` and record the turn.

    Prior turns from ``conversation_history`` are prepended in the
    "You: ... AI: ..." layout. The model samples a continuation, and only the
    newly generated tokens are decoded as the reply.

    Args:
        user_input: the user's message for this turn.
        max_new_tokens: upper bound on generated tokens for the reply
            (keyword added; previously prompt + reply shared a 150-token cap).

    Returns:
        The assistant's reply text. It is also appended, with ``user_input``,
        to ``conversation_history``.
    """
    full_input = " ".join(f"You: {u} AI: {a}" for u, a in conversation_history)
    full_input += f" You: {user_input} AI:"

    # Keep only the newest prompt tokens so prompt + reply stay inside the
    # position-embedding window; the history otherwise grows without bound.
    context_limit = getattr(model.config, "max_position_embeddings", 1024)
    prompt_budget = max(1, context_limit - max_new_tokens)
    input_ids = tokenizer.encode(full_input, return_tensors="pt")[:, -prompt_budget:]
    # After training the model may live on a GPU; inputs must be on its device.
    input_ids = input_ids.to(model.device)
    attention_mask = torch.ones_like(input_ids)

    # The model may still be in training mode after trainer.train();
    # switch dropout off for generation.
    model.eval()
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=0.7,
    )
    # generate() returns prompt + continuation; decode only the new tokens so
    # the reply (and the stored history) does not repeat the whole prompt.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # The model may run on and invent the user's next turn; keep only the reply.
    response = response.split("You:", 1)[0].strip()
    conversation_history.append((user_input, response))
    return response
print("\nAnthropic hh-rlhf chatbot ready! Type 'exit' to quit.\n")
# Read-eval-print loop: one chat() call per non-empty line until "exit",
# end of input (Ctrl-D), or Ctrl-C.
while True:
    try:
        user_input = input("You: ")
    except (EOFError, KeyboardInterrupt):
        # End the session cleanly instead of dumping a traceback.
        print()
        break
    user_input = user_input.strip()
    if user_input.lower() == "exit":
        break
    if not user_input:
        # Nothing to answer; don't waste a generation on an empty prompt.
        continue
    print("AI:", chat(user_input))