"""Fine-tune GPT-2 on a plain-text corpus and chat with the result on the console.

The script loads the pretrained ``gpt2`` checkpoint, fine-tunes it as a causal
language model on ``fitness_data.txt``, saves the model, tokenizer and trainer
state under ``./fitnessbot``, and then starts an interactive prompt loop.
"""
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import torch


def generate_reply(model, tokenizer, device, user_input, max_new_tokens=100):
    """Return the model's continuation of ``"You: " + user_input``.

    Args:
        model: Causal language model exposing ``generate`` and ``config``.
        tokenizer: Tokenizer matching ``model``.
        device: Torch device the model lives on; inputs are moved there.
        user_input: Raw text typed by the user.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens are stripped), with
        surrounding whitespace removed.
    """
    # Leave room for the reply inside the model's context window so that a
    # very long prompt cannot push generation past the position limit.
    prompt_budget = max(1, model.config.n_positions - max_new_tokens)
    encoded = tokenizer(
        "You: " + user_input,
        return_tensors="pt",
        truncation=True,
        max_length=prompt_budget,
    ).to(device)
    prompt_ids = encoded["input_ids"]

    output_ids = model.generate(
        prompt_ids,
        attention_mask=encoded["attention_mask"],
        # Bound the reply length itself; ``max_length`` would also count the
        # prompt tokens and starve long prompts of any reply.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # GPT-2 defines no pad token, so reuse EOS explicitly.
        pad_token_id=tokenizer.eos_token_id,
    )

    # ``generate`` returns prompt + continuation; keep only the continuation
    # so the bot does not echo the user's message back.
    reply_ids = output_ids[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()


# Define your model and tokenizer
model_name = "gpt2"  # You can use different GPT-2 variants like "gpt2-medium" if needed
config = GPT2Config.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name, config=config)

# Move the model to the appropriate device (e.g., GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./fitnessbot",  # Directory to save model checkpoints and results
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    save_steps=1000,
    save_total_limit=2,
)

# Load your dataset
# NOTE(review): TextDataset is marked deprecated in recent transformers
# releases in favour of the `datasets` library; kept here to avoid adding a
# new dependency.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="fitness_data.txt",  # Path to your dataset file
    block_size=256,  # Adjust this block size based on your model's token limit
)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal (next-token) language modelling, not masked LM
)

# Create a Trainer instance and fine-tune the model
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# Fine-tune the model
trainer.train()
trainer.save_model()
trainer.save_state()
# The tokenizer was not handed to the Trainer, so save it next to the model;
# otherwise the output directory cannot be reloaded on its own.
tokenizer.save_pretrained(training_args.output_dir)

# Make sure dropout is disabled before generating replies.
model.eval()

# Chat with the bot
print("Welcome to the Fitness Chatbot. Type 'exit' to end the conversation.")
while True:
    try:
        user_input = input("You: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the chat the same way as typing 'exit'.
        print()
        print("Fitness Chatbot: Goodbye!")
        break

    message = user_input.strip()
    if message.lower() == "exit":
        print("Fitness Chatbot: Goodbye!")
        break
    if not message:
        # Nothing to respond to; prompt again.
        continue

    bot_response = generate_reply(model, tokenizer, device, message)
    print("Fitness Chatbot:", bot_response)