# Author: Zeeshan42
# Last commit: 41ac1e9 "Update app.py" (verified)
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from groq import Groq
import os
# Groq client used by get_response() for chat completions.
# SECURITY: the API key is read from the environment instead of being
# hard-coded in source. A key that was ever committed to a repository must be
# treated as compromised: revoke/rotate it in the Groq console.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    # Fail fast at start-up (before the expensive fine-tuning below) rather
    # than on the first chat request.
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; configure it as a "
        "secret/environment variable before starting the app."
    )
client = Groq(api_key=_groq_api_key)
# Source books for fine-tuning: short title -> Hugging Face dataset id.
# Replace the ids with those of your own uploaded datasets.
# NOTE(review): Hub dataset ids are normally namespaced ("user/name");
# confirm these bare names actually resolve with load_dataset().
book_names = dict(
    DSM="Diagnostic_and_statistical_manual_of_mental_disorders_DSM5",
    Personality="Theories_of_Personality_10",
    SearchForMeaning="Mans_Search_For_Meaning",
)
# Load the books from Hugging Face datasets and cut them into paragraphs.
def load_data(book_names, split="train", text_column="text"):
    """Build a paragraph-level ``Dataset`` from several Hugging Face text datasets.

    Each dataset id is loaded with ``load_dataset``. The ``text_column`` rows of
    ``split`` are re-joined with newlines and then cut into paragraphs on blank
    lines ("\\n\\n"). Whitespace-only paragraphs are dropped. A dataset that
    fails to load (or lacks the split/column) is reported and skipped, so one
    bad id does not abort the others.

    Args:
        book_names: Mapping of display title -> Hugging Face dataset id. Only
            the ids (values) are used.
        split: Split to read from each dataset (default ``"train"``).
        text_column: Column holding the raw text (default ``"text"``).

    Returns:
        A ``Dataset`` with a single ``"text"`` column, one row per paragraph.

    Raises:
        ValueError: If no paragraph could be collected from any dataset. An
            empty dataset cannot be split or trained on, so failing here gives
            a clearer error than failing later.
    """
    data = []
    for book_name in book_names.values():
        try:
            dataset = load_dataset(book_name)
            rows = dataset[split][text_column]
        except Exception as e:  # best-effort per book: report and move on
            print(f"Error loading dataset for {book_name}: {e}")
            continue
        # Column access yields a sequence of row strings, not one string, so
        # calling .split() on it directly fails. Re-join the rows first.
        # NOTE(review): assumes one row per line (the default layout for plain
        # text datasets), so blank lines become "" rows and the "\n\n" split
        # below recovers paragraph boundaries -- confirm for these datasets.
        full_text = "\n".join(row or "" for row in rows)
        data.extend(
            {"text": paragraph.strip()}
            for paragraph in full_text.split("\n\n")
            if paragraph.strip()
        )
    if not data:
        raise ValueError(
            f"No paragraphs could be loaded from: {list(book_names.values())}"
        )
    return Dataset.from_list(data)
# Paragraph-level training corpus built from the configured books.
dataset = load_data(book_names=book_names)

# Base causal-LM checkpoint to fine-tune; a larger model can be swapped in if
# hardware allows.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# This tokenizer has no padding token of its own, so end-of-sequence doubles
# as padding.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
# Tokenize paragraphs and build loss labels for causal language modeling.
def tokenize_function(examples, max_length=512):
    """Tokenize a batch of paragraphs for causal-LM fine-tuning.

    Meant for ``dataset.map(tokenize_function, batched=True)``:
    ``examples["text"]`` is a list of strings, and every returned field is a
    list with one entry per example.

    Every sequence is truncated/padded to exactly ``max_length`` tokens so all
    rows share one length. (``padding=True`` pads each ``map`` batch only to its
    own longest row. That leaves ``labels`` ragged across batches, and the
    padding collator cannot turn ragged labels into a tensor.)

    Labels are a copy of ``input_ids``; Hugging Face causal-LM models shift
    them internally when computing the loss, so no manual shift is needed.
    Padded positions are set to -100 so the loss ignores them. Padding is
    located via ``attention_mask`` rather than by token id, because
    ``pad_token`` is the EOS token.

    Args:
        examples: Batch dict with a ``"text"`` list of strings.
        max_length: Fixed sequence length for truncation and padding.

    Returns:
        The tokenizer's batch encoding with an added ``"labels"`` field.
    """
    encodings = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    # Per-sequence masking: the previous flat comprehension compared whole
    # id lists against an int, so nothing was ever masked.
    encodings["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    return encodings
# Tokenize every paragraph in batches, then hold out 10% of the rows as an
# evaluation set.
# NOTE(review): no seed is given, so the split is not reproducible across runs.
tokenized_dataset = dataset.map(tokenize_function, batched=True)
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
# Fine-tuning hyper-parameters; the eval split is scored after every epoch.
training_args = TrainingArguments(
    output_dir="./results",  # where checkpoints and logs are written
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # `eval_strategy` is the current name of the deprecated `evaluation_strategy`.
    eval_strategy="epoch",
)
# Tie model, data and hyper-parameters together. The tokenizer is passed so the
# Trainer can pad batches with it and save it alongside the model.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch once the minimum supported version allows it.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Run fine-tuning, then write the weights and the tokenizer into the same
# directory so it can be reloaded with from_pretrained().
# NOTE(review): get_response() below answers through the Groq API and never
# loads this fine-tuned model.
trainer.train()
for _artifact in (model, tokenizer):
    _artifact.save_pretrained("./fine_tuned_model")
# Step 4: Define response function with emergency keyword check
def get_response(user_input):
    """Return a chatbot reply for *user_input*, appending a crisis notice if needed.

    The message is sent unchanged to the Groq chat-completions API. If it
    contains any distress keyword (case-insensitive substring match), an
    emergency notice is appended to the reply. The notice is appended even when
    the API call fails or returns no content, so a user in distress always
    sees it.

    Args:
        user_input: The user's message text.

    Returns:
        The model's reply (or a fallback apology on API failure), possibly
        followed by the emergency notice.
    """
    # Substring matching is deliberately loose ("help" also matches
    # "helpless"). For a crisis notice, a false positive costs far less than a
    # miss. Explicit self-harm/suicide phrases were missing from the list.
    distress_keywords = (
        "hopeless", "emergency", "help", "crisis", "urgent",
        "suicid", "kill myself", "self-harm", "self harm",
        "end my life", "want to die",
    )
    lowered = (user_input or "").lower()
    is_distress = any(word in lowered for word in distress_keywords)

    try:
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            # NOTE(review): confirm this model id is still served by Groq.
            model="llama3-8b-8192",  # Or replace with another model
        )
        # Content can be None (e.g. an empty completion); keep `+=` below safe.
        response = chat_completion.choices[0].message.content or ""
    except Exception as e:  # request boundary: report and fall back
        print(f"Error from Groq API: {e}")
        response = "Sorry, I couldn't generate a response right now. Please try again."

    if is_distress:
        # TODO: replace the "[emergency number]" placeholder with a real,
        # locale-appropriate crisis line before deploying.
        response += "\n\nThis seems serious. Please consider reaching out to an emergency contact immediately. In case of an emergency, call [emergency number]."
    return response
# Step 5: Set up Gradio Interface
import gradio as gr


def chatbot_interface(input_text):
    """Gradio callback: pass the text box contents to get_response()."""
    reply = get_response(input_text)
    return reply


# Single text-in / text-out UI around the chatbot; launching starts the server.
demo = gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    title="Virtual Psychiatrist Chatbot",
)
demo.launch()