# i2ebuddy-gpt2 / main.py
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# Load the pretrained GPT-2 tokenizer and model.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# GPT-2 has no padding token by default; reuse the end-of-text token so that
# padding="max_length" below does not raise an error.
tokenizer.pad_token = tokenizer.eos_token

# Load the training split and tokenize each example to a fixed length of 512 tokens.
dataset = load_dataset("i2ebuddy/website_data", split="train")
dataset = dataset.map(
    lambda examples: tokenizer(
        examples["text"], padding="max_length", truncation=True, max_length=512
    ),
    batched=True,
)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

# For causal language modeling (mlm=False) the collator copies input_ids into
# labels, so the Trainer can compute the language-modeling loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="no",  # no eval_dataset is provided, so skip evaluation
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    report_to="none",  # do not report to any logging service
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

trainer.train()
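
# Optional follow-up (a minimal sketch, not part of the original script): persist the
# fine-tuned weights and run a quick generation smoke test. The output directory
# "./results/final" and the prompt text are assumptions chosen for illustration.
trainer.save_model("./results/final")
tokenizer.save_pretrained("./results/final")

prompt = "Welcome to i2ebuddy"  # hypothetical prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))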