# Spaces: Runtime error
# (Status banner captured together with the file; it is not part of the program.)
import os

from datasets import Dataset
from groq import Groq
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
# Initialize the Groq client.
# The API key is read from the GROQ_API_KEY environment variable (for example
# a Space secret) instead of being written into the source, where anyone who
# can read the file can copy it.
# SECURITY: the key that used to be hard-coded on this line has been exposed
# and should be revoked and replaced.
# NOTE(review): when the variable is unset this passes api_key=None; the Groq
# client is expected to reject a missing key at construction - confirm.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Source books used for fine-tuning, keyed by a short title.
# Paths are relative to the working directory (the files are assumed to sit in
# the root directory of the repo).
book_paths = {
    "DSM": "./DSM5.pdf",
    "Personality": "./TheoriesofPersonality.pdf",
    "SearchForMeaning": "./MansSearchForMeaning.pdf",
}
# Function to load and preprocess the data from books
def load_data(paths):
    """Read each book as text and return its paragraphs as a ``Dataset``.

    Args:
        paths: Mapping of book title to file path.

    Returns:
        A ``Dataset`` built from ``{"text": paragraph}`` records, one per
        non-empty paragraph, in the order the files are iterated. A file that
        does not exist is reported on stdout and skipped.
    """
    data = []
    for title, path in paths.items():
        print(f"Attempting to load file for {title} from path: {path}")
        try:
            # NOTE(review): the configured paths end in ".pdf" but are read as
            # UTF-8 text with undecodable bytes dropped. If these are real PDF
            # files the result is unlikely to be readable prose; extracting
            # the text first (outside this function) is probably required.
            with open(path, "r", encoding="utf-8", errors="ignore") as file:
                text = file.read()
        except FileNotFoundError:
            print(f"Error: File for {title} not found at path {path}")
            continue

        # Paragraphs are taken to be separated by a blank line (adjust as needed).
        for paragraph in text.split("\n\n"):
            stripped = paragraph.strip()
            if stripped:  # Skip empty paragraphs
                data.append({"text": stripped})
    return Dataset.from_list(data)
# Load and preprocess dataset for fine-tuning
dataset = load_data(book_paths)

# Load pretrained model and tokenizer from Hugging Face
model_name = "gpt2"  # Replace with a larger model if needed and feasible
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The tokenizer has no padding token of its own, so reuse the end-of-sequence
# token; without this, padded batches cannot be built.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
# Tokenize data and create labels for causal language modeling
def tokenize_function(examples):
    """Tokenize text and attach labels with padding positions masked out.

    Args:
        examples: Mapping with a ``"text"`` entry: a list of strings when
            called through ``Dataset.map(..., batched=True)``, or a single
            string.

    Returns:
        The tokenizer output with an added ``"labels"`` entry. Labels are a
        copy of ``input_ids`` in which every padding position is replaced by
        ``-100`` so that it does not contribute to the loss. The labels are
        NOT shifted here; the causal-LM model is expected to do the
        one-position shift itself.
    """
    # Pad to a fixed length: padding=True pads only to the longest row of the
    # current map() batch, so rows from different batches would end up with
    # different lengths and their labels could not be stacked into one tensor.
    encodings = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]

    def mask_padding(ids, mask):
        # Use the attention mask rather than comparing against pad_token_id:
        # the pad token is the same id as EOS here, so an id comparison cannot
        # tell real EOS tokens from padding.
        return [token if keep else -100 for token, keep in zip(ids, mask)]

    if isinstance(examples["text"], str):
        # Single example: input_ids is one flat list of token ids.
        labels = mask_padding(input_ids, attention_mask)
    else:
        # Batched: input_ids is a list of per-example lists, so masking has to
        # be applied to each sequence, not to the outer list.
        labels = [
            mask_padding(ids, mask)
            for ids, mask in zip(input_ids, attention_mask)
        ]

    encodings["labels"] = labels
    return encodings
# Tokenize the whole dataset in batches.
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Split dataset into train and eval (explicit split for better validation):
# 10% of the rows are held out for evaluation.
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]
# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",          # Output directory for model and logs
    eval_strategy="epoch",           # Use eval_strategy instead of evaluation_strategy
    learning_rate=2e-5,              # Learning rate
    per_device_train_batch_size=8,   # Batch size for training
    per_device_eval_batch_size=8,    # Batch size for evaluation
    num_train_epochs=3,              # Number of training epochs
    weight_decay=0.01,               # Weight decay for regularization
)
# Initialize the Trainer
# NOTE(review): newer transformers releases may deprecate the `tokenizer`
# keyword in favour of `processing_class` - confirm against the pinned version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # Pass eval dataset for evaluation
    tokenizer=tokenizer,        # Provide tokenizer for model inference
)
# Fine-tune the model
trainer.train()

# Save the model after fine-tuning; the tokenizer is written to the same
# directory.
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")
# Step 4: Define response function with emergency keyword check
def get_response(user_input):
    """Return a chat reply for ``user_input``, with a safety notice if needed.

    Args:
        user_input: The user's message as a string.

    Returns:
        The text of the first completion choice returned by the Groq chat API.
        When the message contains any distress keyword, an emergency notice is
        appended to that text.
    """
    # Check for emergency/distress keywords.
    # Matching is by case-insensitive substring, so "help" also matches
    # "helpless" or "helpful". This over-triggers on purpose: a needless
    # notice is cheaper than a missed one.
    distress_keywords = ["hopeless", "emergency", "help", "crisis", "urgent"]
    lowered = user_input.lower()
    is_distress = any(word in lowered for word in distress_keywords)

    # Use Groq API for generating a response.
    # NOTE(review): the reply comes from the hosted model named below; nothing
    # in this function loads the model saved under ./fine_tuned_model.
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": user_input}],
        model="llama3-8b-8192",  # Or replace with another model
    )
    response = chat_completion.choices[0].message.content

    # Append emergency message if distress keywords are detected.
    # TODO(review): "[emergency number]" is a placeholder shown verbatim to
    # the user; replace it with a real number before release.
    if is_distress:
        response += "\n\nThis seems serious. Please consider reaching out to an emergency contact immediately. In case of an emergency, call [emergency number]."
    return response
# Step 5: Set up Gradio Interface
import gradio as gr


def chatbot_interface(input_text):
    """Gradio callback: return the chatbot's reply to ``input_text``."""
    return get_response(input_text)
# Launch the Gradio app: one text input, one text output.
gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    title="Virtual Psychiatrist Chatbot",
).launch()