File size: 3,478 Bytes
4ed4ad6
e7c391d
3762217
4ed4ad6
 
 
 
 
 
 
e7c391d
3762217
4ed4ad6
3762217
4ed4ad6
c61901c
 
3762217
4ed4ad6
3762217
 
 
4ed4ad6
 
 
 
 
 
 
4e2ce18
 
4ed4ad6
e7c391d
3762217
4ed4ad6
3762217
4ed4ad6
 
3762217
4ed4ad6
 
 
 
3762217
4ed4ad6
3762217
 
4ed4ad6
3762217
 
 
 
 
 
4ed4ad6
3762217
 
 
 
 
 
 
 
 
 
4ed4ad6
3762217
 
 
 
4ed4ad6
3762217
 
 
 
c61901c
4ed4ad6
c61901c
4ed4ad6
3762217
4ed4ad6
 
 
3762217
e7c391d
3762217
 
 
 
 
 
 
4ed4ad6
3762217
4ed4ad6
3762217
 
4ed4ad6
3762217
 
e7c391d
 
3762217
 
 
4ed4ad6
 
 
 
 
 
 
3762217
e5e8e5e
3762217
 
4ed4ad6
3762217
 
 
 
4ed4ad6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# app.py
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

# ---------------------------
# Step 1: Load Anthropic Dataset
# ---------------------------
# Download (or reuse the local cache of) the "Anthropic/hh-rlhf" dataset from
# the Hugging Face Hub; every split is kept, and Step 2 reads ds["train"].
print("Loading dataset...")
ds = load_dataset(path="Anthropic/hh-rlhf")

# ---------------------------
# Step 2: Prepare prompt-response pairs
# ---------------------------
def _split_first_exchange(text):
    """Return ``(human, assistant)`` for the first exchange of a transcript.

    Transcripts are expected to use ``Human:`` / ``Assistant:`` turn markers,
    possibly over several turns. Only the first human turn and the assistant
    reply to it are kept. The reply is cut at the next ``Human:`` marker so
    that later turns do not leak into the training target.

    Returns ``None`` when the text contains no ``Assistant:`` marker.
    """
    human_part, marker, rest = text.partition("Assistant:")
    if not marker:
        return None
    human = human_part.replace("Human:", "").strip()
    # Stop at the next human turn; str.split with maxsplit=1 returns the
    # whole remainder when there is no further "Human:" marker.
    assistant = rest.split("Human:", 1)[0].strip()
    return human, assistant


train_data = []
for item in ds["train"]:
    pair = _split_first_exchange(item["chosen"])
    if pair is not None:
        human, assistant = pair
        train_data.append({"input": human, "output": assistant})

print(f"Total training examples: {len(train_data)}")
# Guard the preview: indexing an empty list would raise IndexError.
if train_data:
    print("Example:", train_data[0])

# ---------------------------
# Step 3: Load tokenizer and model
# ---------------------------
model_name = "distilgpt2"
print(f"Loading model and tokenizer: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2 style tokenizers come without a padding token, yet Step 4 pads every
# example to a fixed length. Register a dedicated "[PAD]" token when missing.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# Match the embedding matrix to the (possibly enlarged) vocabulary so any
# newly added token id has an embedding row.
model.resize_token_embeddings(len(tokenizer))

# ---------------------------
# Step 4: Tokenize data
# ---------------------------
def tokenize_function(example):
    """Tokenize one prompt/response pair for causal-LM fine-tuning.

    The text uses the same ``"You: ... AI: ..."`` layout that the chat loop
    builds at inference time, so the model is trained on the prompt format it
    will actually be given. The text ends with the tokenizer's EOS token so
    the model learns where a reply stops. If the example is truncated, the
    EOS token may be cut off.

    Args:
        example: dict with string fields ``"input"`` (human turn) and
            ``"output"`` (assistant reply), as built in Step 2.

    Returns:
        The tokenizer's encoding of the formatted text, truncated and padded
        to 128 tokens.
    """
    # Fall back to no terminator if the tokenizer defines no EOS token.
    eos = tokenizer.eos_token or ""
    text = f"You: {example['input']} AI: {example['output']}{eos}"
    return tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
    )


tokenized_data = [tokenize_function(item) for item in train_data]

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of pre-tokenized examples.

    Each stored example is a mapping from field name (e.g. ``input_ids``) to a
    list of values. Indexing converts every field to a ``torch.tensor`` on the
    fly; the underlying list is kept as-is in ``self.data``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        converted = {}
        for field, values in record.items():
            converted[field] = torch.tensor(values)
        return converted

# Wrap the tokenized examples so the Trainer can index them.
train_dataset = CustomDataset(tokenized_data)

# ---------------------------
# Step 5: Fine-tune model
# ---------------------------
# mlm=False selects the causal-LM objective: labels are copies of input_ids,
# with pad-token positions set to -100 so the loss ignores them.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./hh_rlhf_model",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    save_steps=500,
    logging_steps=50,
    save_total_limit=1,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU exists
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

print("Starting fine-tuning...")
trainer.train()
# Intermediate checkpoints are written only every `save_steps`, and the
# Trainer is not given the tokenizer. Persist the final weights and the
# tokenizer explicitly: if Step 3 added a [PAD] token, the saved model
# cannot be reloaded consistently without the matching tokenizer.
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)
# Training leaves the model in train mode; switch to eval mode so dropout is
# disabled for the generation done by the chat loop below.
model.eval()
print("✅ Training complete!")

# ---------------------------
# Step 6: Simple Chat Loop
# ---------------------------
conversation_history = []


def chat(user_input, *, max_new_tokens=100, temperature=0.7):
    """Generate a reply to *user_input*, conditioned on the conversation so far.

    The prompt replays every ``(user, reply)`` pair in ``conversation_history``
    as ``"You: ... AI: ..."`` and then appends the new user turn, ending with
    ``"AI:"`` so the model continues with the reply.

    Args:
        user_input: The user's latest message.
        max_new_tokens: Maximum number of tokens to generate for the reply.
            Unlike the old ``max_length=150``, this does not count the prompt,
            so replies are still produced once the history grows long.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        Only the newly generated reply text (the prompt is not included).

    Side effects:
        Appends ``(user_input, reply)`` to ``conversation_history``.
    """
    turns = [f"You: {u} AI: {a}" for u, a in conversation_history]
    turns.append(f"You: {user_input} AI:")
    prompt = " ".join(turns)

    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep only the most recent prompt tokens so prompt + reply fit within the
    # tokenizer's declared maximum length.
    budget = max(tokenizer.model_max_length - max_new_tokens, 1)
    # Move inputs to wherever the model lives; after Trainer runs, the model
    # may be on a GPU while freshly tokenized tensors are on the CPU.
    input_ids = encoded["input_ids"][:, -budget:].to(model.device)
    attention_mask = encoded["attention_mask"][:, -budget:].to(model.device)

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=temperature,
    )
    # generate() returns prompt + continuation for this model; decode only the
    # continuation, otherwise the whole prompt would be echoed and fed back
    # into the history on the next turn.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # The model may keep going and invent the user's next turn; drop it.
    response = response.split("You:", 1)[0].strip()
    conversation_history.append((user_input, response))
    return response

print("\nAnthropic hh-rlhf chatbot ready! Type 'exit' to quit.\n")
# Read-eval-print loop: one model reply per line of user input.
while True:
    try:
        user_input = input("You: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C end the session cleanly instead of a traceback.
        print()
        break
    command = user_input.strip()
    if command.lower() == "exit":
        break
    if not command:
        # Nothing to respond to; avoid a pointless generate() call.
        continue
    print("AI:", chat(user_input))