# Hugging Face Spaces status when this file was captured: Runtime error
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments | |
from datasets import load_dataset, Dataset | |
from groq import Groq | |
import os | |
# Groq API client used to generate chatbot replies.
# The API key is read from the GROQ_API_KEY environment variable (e.g. a
# Hugging Face Space secret) instead of being hard-coded in source. Any key
# that was ever committed to source control should be revoked and rotated.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Hugging Face dataset name for each source book, keyed by a short title.
# NOTE(review): these values are passed straight to ``load_dataset``; names
# without a "user/" namespace must exist as top-level Hub datasets — confirm
# they match the uploaded books.
book_names = dict(
    DSM="Diagnostic_and_statistical_manual_of_mental_disorders_DSM5",
    Personality="Theories_of_Personality_10",
    SearchForMeaning="Mans_Search_For_Meaning",
)
def load_data(book_names):
    """Load each book's Hugging Face dataset and split it into paragraphs.

    Args:
        book_names: Mapping of a short title to the dataset name passed to
            ``load_dataset``. Only the values are used; the titles are not.

    Returns:
        A ``Dataset`` with a single ``"text"`` column holding one row per
        non-empty paragraph, across every book that loaded successfully.

    Raises:
        ValueError: If no paragraphs could be loaded from any book. An empty
            dataset cannot be split into train/eval sets or trained on.
    """
    data = []
    for book_name in book_names.values():
        try:
            # NOTE(review): assumes each dataset has a "train" split with a
            # "text" column — confirm against the uploaded datasets.
            dataset = load_dataset(book_name)
            # Indexing a column yields a sequence of strings (one per row),
            # not one string, so rebuild the text before splitting. Joining
            # with "\n" presumably turns blank rows (line-per-row text files)
            # back into the "\n\n" paragraph separators used below.
            text = "\n".join(row or "" for row in dataset["train"]["text"])
        except Exception as e:
            # Best-effort: a book that fails to load is reported and skipped.
            print(f"Error loading dataset for {book_name}: {e}")
            continue
        data.extend(
            {"text": paragraph.strip()}
            for paragraph in text.split("\n\n")
            if paragraph.strip()  # skip empty paragraphs
        )
    if not data:
        raise ValueError(
            "No paragraphs were loaded from any of the configured books; "
            "check the dataset names and their 'train' split / 'text' column."
        )
    return Dataset.from_list(data)
# Fine-tuning corpus assembled from the books listed in ``book_names``.
dataset = load_data(book_names)

# Base checkpoint to fine-tune; "gpt2" is a small default that can be swapped
# for a larger causal LM if the hardware allows.
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token for padding so padded batch tokenization works.
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples, tok=None, max_length=512):
    """Tokenize a batch of paragraphs and attach causal-LM training labels.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)``;
            ``examples["text"]`` is a list of strings (a single string is
            also accepted).
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.
        max_length: Length every sequence is truncated / padded to.

    Returns:
        The tokenizer's encodings plus a ``"labels"`` entry. It copies
        ``input_ids``, with every padding position (``attention_mask == 0``)
        set to -100 so the loss ignores it. No manual shift is applied:
        Hugging Face causal-LM heads shift labels internally.
    """
    if tok is None:
        tok = tokenizer
    # Pad to a fixed length. ``padding=True`` only pads to the longest item of
    # each map batch, so rows from different batches would differ in length,
    # and the shuffled train/eval split can put them in one training batch.
    encodings = tok(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )

    def _mask_padding(ids, mask):
        # Mask by attention mask rather than by comparing against
        # pad_token_id: the pad token may share an ID with real tokens
        # (e.g. when pad_token is set to eos_token).
        return [tid if m == 1 else -100 for tid, m in zip(ids, mask)]

    if isinstance(examples["text"], str):
        encodings["labels"] = _mask_padding(
            encodings["input_ids"], encodings["attention_mask"]
        )
    else:
        encodings["labels"] = [
            _mask_padding(ids, mask)
            for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
        ]
    return encodings
# Tokenize the paragraphs batch-wise, then hold out 10% of the rows for evaluation.
tokenized_dataset = dataset.map(tokenize_function, batched=True)
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
# Fine-tuning hyper-parameters; the held-out split is evaluated after every epoch.
training_args = TrainingArguments(
    output_dir="./results",  # checkpoints and logs
    eval_strategy="epoch",  # newer name for the former ``evaluation_strategy``
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
# Fine-tune with the Hugging Face Trainer, evaluating on the held-out split.
# This runs at module level, i.e. every time the script starts.
# NOTE(review): recent transformers releases deprecate ``tokenizer=`` here in
# favour of ``processing_class=`` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()

# Persist the fine-tuned weights and the tokenizer into the same directory.
# NOTE(review): get_response answers through the Groq API, so this saved model
# is not used for chatbot replies in this file.
for _artifact in (model, tokenizer):
    _artifact.save_pretrained("./fine_tuned_model")
def get_response(user_input, model="llama3-8b-8192", chat_client=None):
    """Return a chat reply for *user_input*, adding a safety notice on distress.

    Args:
        user_input: The user's message text.
        model: Groq chat model ID to query.
        chat_client: Object exposing ``chat.completions.create``; defaults to
            the module-level Groq ``client``.

    Returns:
        The model's reply text, or "" if the API returned no content. An
        emergency notice is appended when the message contains a distress
        keyword.
    """
    if chat_client is None:
        chat_client = client
    # Case-insensitive substring match, so e.g. "hopelessness" also triggers.
    # Self-harm phrases are included as the most urgent signals for this app.
    distress_keywords = (
        "hopeless", "emergency", "help", "crisis", "urgent",
        "suicide", "suicidal", "kill myself", "self-harm", "self harm",
        "end my life",
    )
    lowered = user_input.lower()
    is_distress = any(word in lowered for word in distress_keywords)

    chat_completion = chat_client.chat.completions.create(
        messages=[{"role": "user", "content": user_input}],
        model=model,
    )
    # ``message.content`` may be None; fall back to "" so ``+=`` below is safe.
    response = chat_completion.choices[0].message.content or ""
    if is_distress:
        response += "\n\nThis seems serious. Please consider reaching out to an emergency contact immediately. In case of an emergency, call [emergency number]."
    return response
# Step 5: Set up Gradio Interface | |
import gradio as gr | |
def chatbot_interface(input_text):
    """Gradio callback: pass the textbox contents to ``get_response``."""
    reply = get_response(input_text)
    return reply
# Serve the chatbot as a single-textbox Gradio app.
demo = gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    title="Virtual Psychiatrist Chatbot",
)
demo.launch()