from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling

# Load the quiz dataset from the Hugging Face Hub.
dataset = load_dataset("Percy3822/quiz_model")

# Join each prompt/completion pair into a single "text" field for causal language modeling.
def format_for_training(example):
    # Some completions are stored as dicts; serialize them so they concatenate cleanly.
    if isinstance(example["completion"], dict):
        example["completion"] = str(example["completion"])
    return {"text": example["prompt"] + "\n" + example["completion"]}


dataset = dataset.map(format_for_training)
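
# For illustration only: a hypothetical record (these field values are made up, not taken
# from the Percy3822/quiz_model dataset) showing what format_for_training returns.
_sample = {
    "prompt": "Write a quiz question about the water cycle.",
    "completion": {"question": "What drives evaporation?", "answer": "Solar energy"},
}
print(format_for_training(_sample)["text"])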

# Load the distilgpt2 tokenizer; GPT-2 models have no pad token, so reuse EOS for padding.
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Tokenize the concatenated text, padding/truncating every example to 128 tokens.
def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=128)


dataset = dataset.map(tokenize, batched=True)
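
# Optional sanity check (assumes the dataset has a "train" split, as used below):
# each example should now include input_ids and attention_mask alongside the original columns.
print(dataset["train"].column_names)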

# Load the pretrained distilgpt2 weights to fine-tune.
model = AutoModelForCausalLM.from_pretrained(model_name)

# With mlm=False this collator prepares causal-LM batches, copying input_ids into labels.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# push_to_hub=True uploads checkpoints to the hub_model_id repo each time the model is saved.
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=5,
    push_to_hub=True,
    hub_model_id="Percy3822/quiz_model",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    # No separate validation split is defined, so evaluation reuses the training split.
    eval_dataset=dataset["train"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

# Upload the final model, tokenizer, and model card to the Hub repo.
trainer.push_to_hub()
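
# Minimal inference sketch, assuming the push above succeeded and "Percy3822/quiz_model"
# now holds the fine-tuned weights; the prompt below is a hypothetical example.
from transformers import pipeline

generator = pipeline("text-generation", model="Percy3822/quiz_model", tokenizer="Percy3822/quiz_model")
sample_output = generator("Write a quiz question about photosynthesis:", max_new_tokens=64, do_sample=True, top_p=0.95)
print(sample_output[0]["generated_text"])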