"""Fine-tune distilgpt2 on a local humor corpus using the Hugging Face Trainer API."""

import os
import time

import torch
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
)


class GptHumorTrainer:
    def __init__(self, silent=False) -> None:
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        # Resume from an earlier fine-tuning run if saved weights exist;
        # otherwise fall back to the distilgpt2 base checkpoint.
        save_path = self.local_file_path("SaveState")
        if os.path.isdir(save_path):
            self.model = GPT2LMHeadModel.from_pretrained(save_path)
        else:
            self.model = GPT2LMHeadModel.from_pretrained("distilgpt2")
        self.model.eval()
        if not silent:
            print(f"Model loading took {time.perf_counter() - start_time:.2f} seconds")

    def local_file_path(self, path):
        # Build an absolute path relative to this script's directory.
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def train(self, train_file, epochs=3):
        # Note: Trainer moves the model to its own device (a GPU if one is
        # available); pinning to the CPU here only holds on CPU-only machines.
        device = torch.device("cpu")
        self.model.to(device)

        # TextDataset tokenizes the raw text file and slices it into
        # fixed-length blocks of 128 tokens. (It is deprecated in recent
        # transformers releases in favor of the datasets library, but still
        # works for simple causal-LM fine-tuning.)
        train_dataset = TextDataset(
            tokenizer=self.tokenizer,
            file_path=train_file,
            block_size=128,
        )

        # mlm=False selects causal (next-token) language modeling, which is
        # what GPT-2 is trained with.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        for epoch in range(epochs):
            training_args = TrainingArguments(
                output_dir=f"./results/epoch_{epoch + 1}",
                overwrite_output_dir=True,
                # One pass per loop iteration; the outer loop controls the
                # total epoch count (the original num_train_epochs=3 here
                # multiplied every requested epoch by three).
                num_train_epochs=1,
                per_device_train_batch_size=3,
                save_strategy="no",  # skip automatic checkpoints; weights are saved below
                prediction_loss_only=True,
            )

            trainer = Trainer(
                model=self.model,
                args=training_args,
                data_collator=data_collator,
                train_dataset=train_dataset,
            )

            trainer.train()

        # Persist the fine-tuned weights so the next run resumes from them.
        self.model.save_pretrained(self.local_file_path("SaveState"))
if __name__ == "__main__": |
|
humor_trainer = GptHumorTrainer() |
|
humor_trainer.train(humor_trainer.local_file_path("TrainData.txt"), epochs=5) |
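    # A minimal sketch of using the fine-tuned weights afterwards, assuming
    # the standard GenerationMixin API (model.generate); the prompt text and
    # sampling settings below are illustrative, not from the original script:
    #
    #     prompt_ids = humor_trainer.tokenizer.encode("Why did the", return_tensors="pt")
    #     output_ids = humor_trainer.model.generate(prompt_ids, max_length=40, do_sample=True)
    #     print(humor_trainer.tokenizer.decode(output_ids[0], skip_special_tokens=True))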