"""Fine-tune GPT-2 on local book texts, then serve a Groq-backed chatbot.

Running this script performs, in order:

1. Load paragraphs from the files listed in ``book_paths``.
2. Tokenize them and fine-tune ``gpt2`` with the Hugging Face ``Trainer``.
3. Save the fine-tuned model and tokenizer to ``./fine_tuned_model``.
4. Launch a Gradio text interface whose replies come from the Groq chat API.

The Groq API key is read from the ``GROQ_API_KEY`` environment variable.
"""

import os

from datasets import Dataset
import gradio as gr
from groq import Groq
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Never hard-code credentials in source: read the key from the environment.
# Any key that was previously committed to the repository should be revoked.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; "
        "export it before running this script."
    )
client = Groq(api_key=_groq_api_key)

# Paths to the books (assumed to be in the root directory of the repo).
book_paths = {
    "DSM": "./DSM5.pdf",
    "Personality": "./TheoriesofPersonality.pdf",
    "SearchForMeaning": "./MansSearchForMeaning.pdf",
}

# Substrings that trigger the emergency notice appended by get_response().
DISTRESS_KEYWORDS = ("hopeless", "emergency", "help", "crisis", "urgent")

MODEL_NAME = "gpt2"  # Replace with a larger model if needed and feasible.
MAX_LENGTH = 512


def load_data(paths):
    """Read each file in *paths* and return its paragraphs as a Dataset.

    Args:
        paths: Mapping of a human-readable title to a file path.

    Returns:
        A ``datasets.Dataset`` with a single ``"text"`` column holding one
        stripped, non-empty paragraph per row. Files that do not exist are
        reported on stdout and skipped.
    """
    data = []
    for title, path in paths.items():
        print(f"Attempting to load file for {title} from path: {path}")
        # NOTE(review): the configured paths end in ".pdf" but are read here
        # as UTF-8 text with undecodable bytes dropped. Confirm these files
        # are really plain text, or extract the text from the PDFs first.
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as file:
                text = file.read()
        except FileNotFoundError:
            print(f"Error: File for {title} not found at path {path}")
            continue
        # Paragraphs are separated by blank lines (adjust as needed).
        data.extend(
            {"text": paragraph.strip()}
            for paragraph in text.split("\n\n")
            if paragraph.strip()
        )
    return Dataset.from_list(data)


# Load and preprocess the dataset for fine-tuning.
dataset = load_data(book_paths)
if len(dataset) == 0:
    # Fail here with a clear message instead of an opaque error further down.
    raise RuntimeError(
        "No training text was loaded; check that the files in book_paths exist "
        "and contain text."
    )

# Load the pretrained model and tokenizer from Hugging Face.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# GPT-2 ships without a padding token, so reuse the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


def tokenize_function(examples):
    """Tokenize a batch of texts and attach causal-LM labels.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)``;
            ``examples["text"]`` is a list of strings.

    Returns:
        The tokenizer output with an added ``"labels"`` entry: a copy of
        ``input_ids`` in which every padding position is replaced by -100.
    """
    # Pad to a fixed length so that every stored row has the same size,
    # whichever map batch it was tokenized in.
    encodings = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )
    # Labels are the input ids themselves. Padding positions are set to -100
    # so they are excluded from the loss. The attention mask is used rather
    # than the pad token id because pad and EOS share an id.
    encodings["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    return encodings


tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Split the dataset into train and eval (explicit split for validation).
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

# Define training arguments.
training_args = TrainingArguments(
    output_dir="./results",  # Output directory for model and logs
    eval_strategy="epoch",  # Evaluate at the end of every epoch
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,  # Weight decay for regularization
)

# Initialize the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

# Fine-tune the model.
trainer.train()

# Save the model and tokenizer after fine-tuning.
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")


def get_response(user_input):
    """Return a chat reply for *user_input*, flagging possible distress.

    The reply is generated by the Groq chat completions API (not by the
    locally fine-tuned model). If the input contains any of
    ``DISTRESS_KEYWORDS`` (case-insensitive substring match), an emergency
    notice is appended to the reply.

    Args:
        user_input: The user's message.

    Returns:
        The reply text, with the emergency notice appended when applicable.
    """
    lowered = user_input.lower()
    is_distress = any(word in lowered for word in DISTRESS_KEYWORDS)

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": user_input}],
        model="llama3-8b-8192",  # Or replace with another model
    )
    response = chat_completion.choices[0].message.content

    if is_distress:
        # TODO: replace the "[emergency number]" placeholder with the real
        # number for the deployment region before release.
        response += "\n\nThis seems serious. Please consider reaching out to an emergency contact immediately. In case of an emergency, call [emergency number]."
    return response


def chatbot_interface(input_text):
    """Gradio callback: forward the input text to get_response()."""
    return get_response(input_text)


# Launch the Gradio app.
gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    title="Virtual Psychiatrist Chatbot",
).launch()