# HumorGPT / __init__.py
# TheAutonomous's picture
# Upload 4 files
# 0f9b91a
import os
import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
class GptHumorTrainer:
    """Load a saved GPT-2 language model and fine-tune it on a text file.

    The tokenizer is the pretrained "distilgpt2" tokenizer. The model weights
    are read from, and written back to, the "SaveState" directory that sits
    next to this module.
    """

    def __init__(self, silent: bool = False) -> None:
        """Load the tokenizer and the saved model, then switch to eval mode.

        Args:
            silent: When true, skip printing how long loading took.
        """
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        self.model.eval()
        if not silent:
            print(f"Model Loading Took {time.perf_counter()-start_time} Seconds")

    def local_file_path(self, path: str) -> str:
        """Return *path* joined onto the absolute directory of this module.

        Args:
            path: File or directory name relative to this module's directory.

        Returns:
            The joined path as a string.
        """
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def train(
        self,
        train_file: str,
        epochs: int = 3,
        *,
        block_size: int = 128,
        batch_size: int = 3,
    ) -> None:
        """Fine-tune the model on *train_file*, saving after every epoch.

        Each loop iteration builds a fresh ``Trainer``, runs exactly one pass
        over the dataset, and writes the model to the "SaveState" directory,
        so an interrupted run keeps the result of the last completed epoch.

        Args:
            train_file: Path to the plain-text training file.
            epochs: Number of single-epoch training rounds to run.
            block_size: Token block size passed to ``TextDataset``.
            batch_size: Per-device training batch size.
        """
        # Move the model to the CPU before handing it to the Trainer.
        device = torch.device("cpu")
        self.model.to(device)

        # Prepare the dataset.
        train_dataset = TextDataset(
            tokenizer=self.tokenizer,
            file_path=train_file,
            block_size=block_size,
        )

        # mlm=False selects the non-masked language modeling collator.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        for epoch in range(epochs):
            training_args = TrainingArguments(
                # Each epoch gets its own output directory.
                output_dir=f"./results/epoch_{epoch+1}",
                overwrite_output_dir=True,
                # One pass per loop iteration; the outer loop supplies the
                # total epoch count. (This was 3, which silently trained
                # for 3 * epochs passes.)
                num_train_epochs=1,
                per_device_train_batch_size=batch_size,
                # NOTE(review): how a negative save_steps is treated depends
                # on the transformers version - confirm. The explicit
                # save_pretrained() call below is the save this class
                # relies on.
                save_steps=-1,
                save_total_limit=None,
                prediction_loss_only=True,
            )

            trainer = Trainer(
                model=self.model,
                args=training_args,
                data_collator=data_collator,
                train_dataset=train_dataset,
            )

            trainer.train()

            # Persist the model after each epoch.
            self.model.save_pretrained(self.local_file_path("SaveState"))

        # Training leaves the model in training mode; restore the eval mode
        # that __init__ established.
        self.model.eval()
if __name__ == "__main__":
    # Script entry point: load the saved model, then fine-tune it on the
    # training text stored next to this module. Point `training_text` at a
    # different file to train on another corpus.
    trainer_instance = GptHumorTrainer()
    training_text = trainer_instance.local_file_path("TrainData.txt")
    trainer_instance.train(training_text, epochs=5)