# Author: Zeeshan42
# Last commit: e85a115 "Update app.py" (verified)
import os
import re

import groq
from datasets import Dataset
from groq import Groq
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
# Initialize the Groq client. The API key is read from the GROQ_API_KEY
# environment variable (e.g. a Hugging Face Space secret) instead of being
# hard-coded: a key committed to source control is readable by anyone with
# access to the repository. The key previously committed here must be
# revoked and rotated.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Source books for fine-tuning, keyed by title. The "./" paths are resolved
# relative to the current working directory (expected to be the repo root).
book_paths = dict(
    DSM="./DSM5.pdf",
    Personality="./TheoriesofPersonality.pdf",
    SearchForMeaning="./MansSearchForMeaning.pdf",
)
# Function to load and preprocess the data from books
def load_data(paths):
    """Load the given files and return their paragraphs as a Dataset.

    Each file is read as UTF-8 text (undecodable bytes are dropped) and split
    on blank lines, including lines that contain only whitespace. Every
    non-empty, stripped paragraph becomes one ``{"text": ...}`` row. Files
    that are missing or unreadable are reported and skipped.

    Args:
        paths: Mapping of a human-readable title to a file path.

    Returns:
        A ``datasets.Dataset`` built from the collected paragraph rows.

    Raises:
        ValueError: If no paragraph could be loaded from any file. An empty
            dataset has no "text" column, so tokenizing and splitting it
            downstream would fail with a far less obvious error.
    """
    data = []
    for title, path in paths.items():
        print(f"Attempting to load file for {title} from path: {path}")
        try:
            with open(path, "r", encoding="utf-8", errors='ignore') as file:
                text = file.read()
        except FileNotFoundError:
            print(f"Error: File for {title} not found at path {path}")
            continue
        except OSError as err:
            # e.g. permission denied, or the path is a directory.
            print(f"Error: Could not read file for {title} at path {path}: {err}")
            continue
        if text.startswith("%PDF"):
            # Reading a PDF in text mode yields its raw byte stream (mostly
            # compressed objects), not the book's prose, so the "paragraphs"
            # below will be noise. Extract plain text before fine-tuning.
            print(
                f"Warning: {path} is a PDF; its raw contents are not readable text. "
                "Convert it to plain text before fine-tuning."
            )
        data.extend(
            {"text": chunk.strip()}
            for chunk in re.split(r"\n\s*\n", text)
            if chunk.strip()
        )
    if not data:
        raise ValueError("No paragraphs were loaded from any of the provided files.")
    return Dataset.from_list(data)
# Load and preprocess dataset for fine-tuning
# NOTE(review): these statements, and the training run further down, execute
# at module import time, so every start of this script re-downloads the model
# and re-runs fine-tuning before the chatbot UI is launched.
dataset = load_data(book_paths)
# Load pretrained model and tokenizer from Hugging Face
model_name = "gpt2" # Replace with a larger model if needed and feasible
tokenizer = AutoTokenizer.from_pretrained(model_name)
# This tokenizer has no padding token; alias it to the EOS token so that
# padded, batched tokenization works. Consequence: pad_token_id == eos_token_id.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
# Tokenize data and create labels for causal language modeling
def tokenize_function(examples, max_length=512):
    """Tokenize a batch of paragraphs and attach causal-LM labels.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["text"]`` is a
    list of strings, and every returned field holds one list per example.

    Every sequence is truncated/padded to exactly ``max_length`` tokens.
    Padding to a fixed length, instead of to the longest example in the
    current map chunk, keeps ``labels`` the same length across the whole
    dataset, so training batches that mix examples from different chunks can
    still be stacked into one tensor.

    Labels are a copy of ``input_ids`` with padding positions set to -100 so
    the loss ignores them. No manual shift is needed: Hugging Face causal-LM
    models shift the labels internally when computing the loss.

    Args:
        examples: Batch dict with a "text" key holding a list of strings.
        max_length: Fixed sequence length after truncation/padding.

    Returns:
        The tokenizer's encoding with an added "labels" field.
    """
    encodings = tokenizer(
        examples["text"], truncation=True, padding="max_length", max_length=max_length
    )
    # In batched mode input_ids is a list of lists, so mask token by token.
    # The attention mask (0 = padding) identifies padding rather than
    # comparing against pad_token_id, because the pad token is aliased to EOS
    # and a genuine EOS token must not be masked.
    encodings["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    return encodings
# Tokenize in batches. The raw "text" column is dropped because only token
# ids, attention masks and labels are fed to the model.
tokenized_dataset = dataset.map(
    tokenize_function, batched=True, remove_columns=dataset.column_names
)
# Hold out 10% of the rows for evaluation. The fixed seed makes the split,
# and therefore the reported eval loss, reproducible across runs.
train_test_split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]
# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",  # Checkpoints and logs
    eval_strategy="epoch",  # Evaluate after every epoch (newer name of evaluation_strategy)
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,  # Weight decay for regularization
)
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Used by the default data collator for padding and saved with checkpoints.
    tokenizer=tokenizer,
)
# Fine-tune the model
trainer.train()
# Save the fine-tuned weights and tokenizer.
# NOTE(review): get_response() answers via the Groq API and never loads this
# directory, so this fine-tuning currently has no effect on the chatbot.
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")
# Response function with an emergency keyword check
def get_response(user_input):
    """Return the chatbot's reply to *user_input*.

    The message is sent as a single user turn to Groq's ``llama3-8b-8192``
    chat model. If the message contains any distress keyword
    (case-insensitive substring match), an emergency notice is appended to
    the reply.

    Args:
        user_input: The user's message; ``None`` is treated as empty.

    Returns:
        The reply text. Empty or whitespace-only input returns a short prompt
        without calling the API. If the Groq request fails, a fallback
        apology is used as the reply instead of raising, so a user in
        distress still receives the emergency notice.
    """
    text = user_input or ""
    if not text.strip():
        return "Please type a message so I can respond."

    # Substring matching is deliberate: it also catches inflections such as
    # "hopelessness" or "suicidal". A false positive (e.g. "helpful") only adds
    # the notice, which is the safer failure mode for a crisis check.
    distress_keywords = (
        "hopeless", "emergency", "help", "crisis", "urgent",
        "suicid", "kill myself", "end my life", "self-harm", "self harm",
    )
    lowered = text.lower()
    is_distress = any(word in lowered for word in distress_keywords)

    try:
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": text}],
            model="llama3-8b-8192",  # Or replace with another model
        )
        # The SDK types message content as optional; never concatenate onto None.
        response = chat_completion.choices[0].message.content or ""
    except groq.APIError as err:
        print(f"Groq API request failed: {err}")
        response = "Sorry, I'm having trouble responding right now. Please try again in a moment."

    if is_distress:
        response += (
            "\n\nThis seems serious. Please consider reaching out to an emergency contact immediately. "
            "In case of an emergency, call your local emergency number "
            "(for example 911 in the US or 112 in the EU)."
        )
    return response
# Set up the Gradio interface
import gradio as gr
def chatbot_interface(input_text):
    """Gradio callback: hand the submitted text to ``get_response`` and return its reply."""
    reply = get_response(input_text)
    return reply
# Build the single-textbox-in / textbox-out UI around the chatbot and serve it.
demo = gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    title="Virtual Psychiatrist Chatbot",
)
demo.launch()