# Spaces: Runtime error
import logging
import os
import random

import gradio as gr
import groq  # Assuming you are using the Groq library
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForSequenceClassification
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Secrets are looked up in the environment BY NAME and must never be written
# into source. Any credential that was previously committed in this file
# should be treated as leaked and rotated.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN") or os.getenv("HF_TOKEN")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Authenticate with Hugging Face only when a token is configured: calling
# login() without a token falls back to an interactive prompt, which cannot
# be answered in a headless deployment.
if HUGGING_FACE_TOKEN:
    login(HUGGING_FACE_TOKEN)

# Load the mental health counseling conversations dataset and keep the two
# text columns of the training split at module level.
ds = load_dataset("Amod/mental_health_counseling_conversations")
context = ds["train"]["Context"]
response = ds["train"]["Response"]

# Load FLAN-T5 model and tokenizer for primary RAG
flan_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
flan_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

# Load sentiment analysis model
sentiment_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
sentiment_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

# Groq client setup, created once. Without an API key the client is left as
# None so the app can still start and serve answers from the local model.
# NOTE(review): the Groq SDK is expected to raise at construction time when no
# key is available - confirm against the installed SDK version.
client = groq.Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
# Function for sentiment analysis
def analyze_sentiment(text):
    """Classify *text* as positive or negative with the sentiment model.

    Args:
        text: The message to classify.

    Returns:
        A ``(sentiment, confidence)`` tuple. ``sentiment`` is ``"positive"``
        when the highest-scoring class index is 1 and ``"negative"``
        otherwise; ``confidence`` is the largest softmax probability as a
        Python float.
    """
    # Truncate so an over-long message cannot exceed the model's maximum
    # sequence length.
    inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = sentiment_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    sentiment = "positive" if torch.argmax(probs) == 1 else "negative"
    confidence = probs.max().item()
    return sentiment, confidence
# Function to generate response based on sentiment and user input
def generate_response(sentiment, user_input):
    """Generate supportive advice with FLAN-T5 for a message and its sentiment.

    Args:
        sentiment: Sentiment label inserted into the prompt (e.g. the label
            returned by ``analyze_sentiment``).
        user_input: The user's message.

    Returns:
        The decoded model output with special tokens removed; generation is
        capped at ``max_length=150``.
    """
    prompt = f"The user feels {sentiment}. Respond with supportive advice based on: {user_input}"
    # Bound the encoder input to 512 tokens so a very long prompt is
    # truncated rather than passed to the model at full length.
    inputs = flan_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    output_ids = flan_model.generate(**inputs, max_length=150)
    return flan_tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Maximum number of whitespace-separated words accepted in one message.
MAX_INPUT_WORDS = 50

# Groq chat model used for the primary answer.
GROQ_MODEL = "llama3-8b-8192"  # Change model if needed


def _groq_response(user_input):
    """Return Groq's reply to *user_input*, or None when none is available.

    The Groq call is best-effort: a missing client or any failure of the
    request is logged and reported as None so the caller can fall back to
    the local model.
    """
    if client is None:
        return None
    try:
        completion = client.chat.completions.create(
            messages=[{
                "role": "user",
                "content": user_input,
            }],
            model=GROQ_MODEL,
        )
        return completion.choices[0].message.content
    except Exception:
        # Deliberately broad: whatever goes wrong here must not break the
        # chat, but it must not disappear silently either.
        logging.getLogger(__name__).exception(
            "Groq request failed; falling back to FLAN-T5"
        )
        return None


def _random_dataset_example():
    """Return one ``(context, response)`` pair picked at random from the dataset.

    NOTE(review): random selection is an assumption about the intended
    behaviour; confirm, or replace with a lookup relevant to the message.
    """
    size = min(len(context), len(response))
    if size == 0:
        return "", ""
    index = random.randrange(size)
    return context[index], response[index]


# Main chatbot function
def chatbot(user_input):
    """Answer a user's message, preferring Groq and falling back to FLAN-T5.

    Args:
        user_input: The text entered by the user.

    Returns:
        A Markdown string: a prompt to enter text when the input is empty,
        a length warning when it exceeds ``MAX_INPUT_WORDS`` words, the Groq
        reply when one is available, and otherwise a report combining the
        sentiment analysis, the FLAN-T5 reply, a dataset example and the
        word counts.
    """
    if not user_input or not user_input.strip():
        return "Please enter a question or concern to receive guidance."

    # Word count limit
    word_count = len(user_input.split())
    remaining_words = MAX_INPUT_WORDS - word_count
    if remaining_words < 0:
        return f"Your input is too long. Please limit it to {MAX_INPUT_WORDS} words."

    # Primary path: personalized response from Groq.
    brief_response = _groq_response(user_input)
    if brief_response:
        return f"**Personalized Response from Groq:** {brief_response}"

    # Fallback path: the sentiment is only used here, so it is computed only
    # once the Groq path has produced nothing.
    sentiment, confidence = analyze_sentiment(user_input)
    generated_response = generate_response(sentiment, user_input)
    if not generated_response:
        return "I'm sorry, I don't have information specific to your concern. Please consult a professional."

    context_text, dataset_response = _random_dataset_example()

    # Final response with different sources
    return (
        f"**Sentiment Analysis:** {sentiment} (Confidence: {confidence:.2f})\n\n"
        f"**Generated Response:**\n{generated_response}\n\n"
        f"**Contextual Information:**\n{context_text}\n\n"
        f"**Additional Dataset Response:**\n{dataset_response}\n\n"
        f"Words entered: {word_count}, Words remaining: {remaining_words}"
    )
# Set up Gradio interface: a multi-line textbox feeds `chatbot`, and the
# returned string is rendered as Markdown.
question_box = gr.Textbox(
    lines=4,
    placeholder="Describe how you're feeling today...",
    label="Ask your question:",
)
answer_panel = gr.Markdown(label="Psychologist Assistant Response")

interface = gr.Interface(
    fn=chatbot,
    inputs=question_box,
    outputs=answer_panel,
    title="Virtual Psychiatrist Assistant",
    description="Enter your mental health concerns, and receive guidance and responses from a trained assistant.",
    theme="huggingface",
)

# Launch the app
interface.launch()