|
import logging
import os
import random

import gradio as gr
import groq
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
|
|
|
# Counseling-conversation dataset, loaded once when the module is imported.
ds = load_dataset("Amod/mental_health_counseling_conversations")

# The two text columns of the "train" split. chatbot() indexes both with the
# same random row index, so entry i of `context` pairs with entry i of
# `response`.
context, response = ds["train"]["Context"], ds["train"]["Response"]
|
|
|
|
|
# Pretrained T5 checkpoint that chatbot() uses for its local (non-Groq)
# text generation. Tokenizer is loaded first, then the model.
model_name = "t5-small"
tokenizer, model = (
    T5Tokenizer.from_pretrained(model_name),
    T5ForConditionalGeneration.from_pretrained(model_name),
)
|
|
|
|
|
# Groq credentials come from the environment instead of being hard-coded.
# NOTE(review): the key that used to be committed on this line is exposed in
# source control and should be revoked/rotated in the Groq console.
api_key = os.environ.get("GROQ_API_KEY")

# With no key configured, leave the client unset rather than crashing at
# import time. chatbot() wraps its Groq call in a try/except and falls back
# to the local model when that call fails.
client = groq.Client(api_key=api_key) if api_key else None
|
|
|
|
|
def chatbot(user_input): |
|
if not user_input.strip(): |
|
return "Please enter a question or concern to receive guidance." |
|
|
|
|
|
word_count = len(user_input.split()) |
|
max_words = 50 |
|
remaining_words = max_words - word_count |
|
|
|
if remaining_words < 0: |
|
return f"Your input is too long. Please limit to {max_words} words. Words remaining: 0." |
|
|
|
|
|
try: |
|
brief_response = client.predict(user_input) |
|
except Exception as e: |
|
brief_response = None |
|
|
|
if brief_response: |
|
return f"**Personalized Response:** {brief_response}" |
|
|
|
|
|
idx = random.randint(0, len(context) - 1) |
|
context_text = context[idx] |
|
response_text = response[idx] |
|
|
|
|
|
inputs = tokenizer.encode("summarize: " + user_input, return_tensors="pt", max_length=512, truncation=True) |
|
summary_ids = model.generate(inputs, max_length=100, num_beams=4, early_stopping=True) |
|
generated_response = tokenizer.decode(summary_ids[0], skip_special_tokens=True) |
|
|
|
if not generated_response: |
|
return "Oops, sorry, I don't have information about your specific problem. Please visit a doctor to prevent mishaps." |
|
|
|
|
|
complete_response = ( |
|
f"**Contextual Information:**\n{context_text}\n\n" |
|
f"**Generated Response:**\n{generated_response}\n\n" |
|
f"**Fallback Response:**\n{response_text}" |
|
) |
|
|
|
return f"{complete_response}\n\nWords entered: {word_count}, Words remaining: {remaining_words}" |
|
|
|
|
|
# Single-textbox Gradio UI; chatbot() returns Markdown, rendered by gr.Markdown.
interface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(
        label="Ask your question:",
        placeholder="Describe how you're feeling today...",
        lines=4,
    ),
    outputs=gr.Markdown(label="Psychologist Assistant Response"),
    title="Virtual Psychiatrist Assistant",
    description="Enter your mental health concerns, and receive guidance and responses from a trained assistant.",
    # "huggingface" is not a built-in theme in current Gradio releases (it
    # appears to be a legacy 3.x name). "default" is valid in both the legacy
    # and the current theming systems.
    theme="default",
)

# Only start the server when run as a script, so importing this module
# (e.g. for tests or `gradio` reload mode) does not block on launch().
if __name__ == "__main__":
    interface.launch()
|
|
|
|
|
|
|
|