# FitnessChatbot / finetune.py
# Uploaded by MarshalAM: "Upload 3 files" (revision 04f76c0)
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import torch
# Pretrained checkpoint to start from. Other GPT-2 variants such as
# "gpt2-medium" can be substituted here if needed.
model_name = "gpt2"

# Configuration, tokenizer and LM-head model are all loaded from the same
# checkpoint name so that they stay consistent with each other.
config = GPT2Config.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name, config=config)

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
# Training hyper-parameters, collected in one mapping and then handed to
# TrainingArguments as keyword arguments.
_training_options = {
    # Directory to save model checkpoints and results.
    "output_dir": "./fitnessbot",
    "overwrite_output_dir": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 2,
    # Checkpointing cadence and how many checkpoints to retain.
    "save_steps": 1000,
    "save_total_limit": 2,
}
training_args = TrainingArguments(**_training_options)
# Training corpus: a plain-text file that TextDataset tokenizes with the
# tokenizer loaded above and cuts into fixed-size blocks.
_corpus_path = "fitness_data.txt"  # Path to your dataset file
_block_size = 256  # Adjust based on your model's token limit
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=_corpus_path,
    block_size=_block_size,
)

# mlm=False turns off masked-language-model masking in the collator, which
# is the setting used for a causal model such as GPT-2.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Create a Trainer instance wired to the model, arguments, collator and
# training set defined above.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# Fine-tune the model.
trainer.train()

# Persist the fine-tuned weights and the trainer state to the output directory.
trainer.save_model()
trainer.save_state()

# The Trainer was not handed the tokenizer, so depending on the transformers
# version save_model() may write only the model files. Saving the tokenizer
# explicitly keeps the output directory loadable on its own with
# from_pretrained(); writing it twice is harmless.
tokenizer.save_pretrained(training_args.output_dir)
# Chat with the bot: read a line from the user, let the fine-tuned model
# continue the prompt "You: <text>", and print only the newly generated text.
# Each turn is independent; no conversation history is carried over.

# Training leaves the model in train mode; switch to eval mode so dropout is
# disabled while generating replies.
model.eval()

print("Welcome to the Fitness Chatbot. Type 'exit' to end the conversation.")
while True:
    try:
        user_input = input("You: ")
    except (EOFError, KeyboardInterrupt):
        # stdin was closed or the user pressed Ctrl-C: leave the same way as
        # typing 'exit' instead of crashing with a traceback.
        print()
        print("Fitness Chatbot: Goodbye!")
        break
    if user_input.strip().lower() == 'exit':
        print("Fitness Chatbot: Goodbye!")
        break

    # Calling the tokenizer (rather than .encode) also returns the
    # attention mask, which generate() needs to be given explicitly.
    encoded = tokenizer("You: " + user_input, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    # Generate a response.
    # - max_new_tokens bounds only the reply; max_length would also count the
    #   prompt tokens and leave no room after a long question.
    # - GPT-2 has no pad token, so the EOS token is used for padding.
    response = model.generate(
        **encoded,
        max_new_tokens=100,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens; drop them so the
    # bot does not echo the user's own message back.
    reply_ids = response[0][prompt_length:]
    bot_response = tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
    print("Fitness Chatbot:", bot_response)