Spaces:
Runtime error
Runtime error
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# FLAN-T5 is an encoder-decoder checkpoint, and this script trains it to
# produce the "chosen" text from the "prompt" text. That is a text-generation
# objective, so the seq2seq language-modelling head is required; a
# sequence-classification head cannot consume per-token target ids.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(device)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

# Only the "train" split of the dataset is used by this script.
dataset = load_dataset("athirdpath/DPO_Pairs-Roleplay-Alpaca-NSFW")
train_dataset = dataset["train"]

# Upper bound (in tokens) applied to both the prompt and the target text.
max_length = 512  # You can change this according to your needs
def preprocess(example):
    """Tokenize prompt/chosen text into features for seq2seq training.

    Accepts either a single example (``example["prompt"]`` is a string) or a
    batch as supplied by ``Dataset.map(..., batched=True)`` (the values are
    lists of strings). Both the prompt and the chosen text are truncated and
    padded to the module-level ``max_length``.

    Args:
        example: Mapping with at least the keys ``"prompt"`` and ``"chosen"``.

    Returns:
        A dict of plain Python lists with the keys:

        * ``"input_ids"`` / ``"attention_mask"`` - the tokenized prompt.
        * ``"output_ids"`` - the tokenized chosen text, padding included
          (kept for backward compatibility with earlier callers).
        * ``"labels"`` - the same ids with every padding position replaced
          by -100 so that padding does not contribute to the loss.
    """
    # No return_tensors / .to(device) here: the result is stored in the
    # dataset, which holds plain values; batching and device placement are
    # left to the training loop.
    encoded_prompt = tokenizer(
        example["prompt"],
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    encoded_target = tokenizer(
        example["chosen"],
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    output_ids = encoded_target["input_ids"]

    pad_id = tokenizer.pad_token_id

    def mask_padding(ids):
        # -100 is the index the cross-entropy loss ignores by default.
        return [-100 if token == pad_id else token for token in ids]

    if isinstance(example["chosen"], str):
        # Single example: one flat list of token ids.
        labels = mask_padding(output_ids)
    else:
        # Batched call: one list of token ids per example.
        labels = [mask_padding(ids) for ids in output_ids]

    return {
        "input_ids": encoded_prompt["input_ids"],
        "attention_mask": encoded_prompt["attention_mask"],
        "output_ids": output_ids,
        "labels": labels,
    }
# Tokenize the training data in batches.
train_dataset = train_dataset.map(preprocess, batched=True)

# Hold out part of the tokenized data for evaluation. The script only reads
# the "train" split, yet step-wise evaluation, load_best_model_at_end and the
# final evaluate() call all need a separate dataset to score against.
# A fixed seed keeps the split reproducible between runs.
splits = train_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits["train"]
test_dataset = splits["test"]

# Define the training arguments
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluation and checkpointing must use the same strategy for
    # load_best_model_at_end to work.
    # NOTE(review): newer transformers releases name this argument
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="steps",
    save_strategy="steps",
    logging_dir="logs",
    load_best_model_at_end=True,
)

# Define the trainer
trainer = Trainer(
    model=model,  # The model to train
    args=training_args,  # The training arguments
    train_dataset=train_dataset,  # The training dataset
    eval_dataset=test_dataset,  # Held-out data used for periodic evaluation
)

# Train the model
trainer.train()

# Evaluate the model on the held-out split
trainer.evaluate(test_dataset)