# abc / app.py (author: shibam007)
# Commit d1d46f8: "Update app.py" (verified)
import os
import time
import uuid

import chromadb
import google.generativeai as genai
import streamlit as st
from dotenv import load_dotenv
from transformers import pipeline
# Load API key from .env
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Configure Gemini AI
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-1.5-pro")


# Streamlit re-executes this whole script on every user interaction, so the
# heavyweight resources below are built once per server process through
# st.cache_resource instead of being re-created on every rerun.
@st.cache_resource
def _open_chat_store():
    """Open the on-disk ChromaDB store and its ``chat_history`` collection.

    Returns:
        A ``(client, collection)`` tuple.
    """
    chroma_client = chromadb.PersistentClient(path="./mental_health_memory")
    return chroma_client, chroma_client.get_or_create_collection(name="chat_history")


@st.cache_resource
def _load_sentiment_pipeline():
    """Build the DistilBERT-based emotion text-classification pipeline."""
    # return_all_scores=True asks for every label's score on each call;
    # analyze_sentiment is written against that output shape.
    # NOTE(review): newer transformers versions deprecate this flag in favour
    # of top_k=None, which returns a flat list for a single string input.
    return pipeline(
        "text-classification",
        model="bhadresh-savani/distilbert-base-uncased-emotion",
        return_all_scores=True,
    )


# Initialize ChromaDB for chat memory
client, collection = _open_chat_store()
# Load Sentiment Analysis Model
sentiment_pipeline = _load_sentiment_pipeline()
# Function to analyze sentiment
def analyze_sentiment(text, classifier=None):
    """Classify the emotion expressed in *text*.

    Args:
        text: The text to classify. Input longer than the model's maximum
            sequence length is truncated instead of raising an error.
        classifier: Optional callable used instead of the module-level
            ``sentiment_pipeline``. It is called as
            ``classifier(text, truncation=True)`` and must return either a
            list holding one list of ``{"label", "score"}`` dicts, or that
            inner list directly.

    Returns:
        A ``(label, scores)`` tuple. ``scores`` maps every label reported by
        the classifier to its score, and ``label`` is the key with the
        highest score.
    """
    if classifier is None:
        classifier = sentiment_pipeline
    # The emotion model has a fixed maximum input length. Callers such as the
    # "Sentiment Analysis" button pass the whole conversation, which can
    # exceed it, so ask the tokenizer to truncate instead of failing.
    results = classifier(text, truncation=True)
    # return_all_scores=True yields [[{...}, ...]] for a single string, while
    # a top_k=None configuration yields [{...}, ...]. Accept either shape.
    if results and isinstance(results[0], list):
        per_label = results[0]
    else:
        per_label = results
    emotions = {res["label"]: res["score"] for res in per_label}
    return max(emotions, key=emotions.get), emotions
# Function to store chat history in ChromaDB
def store_chat(user_input, bot_response):
    """Persist one user/bot exchange to the ChromaDB ``collection``.

    Both messages are added in a single call and tagged with their role in
    metadata (``"user"`` / ``"bot"``). Each exchange gets a fresh random ID
    prefix, so IDs do not collide across turns, concurrent sessions, or after
    the collection has been cleared.

    Args:
        user_input: The message the user typed.
        bot_response: The reply shown to the user.
    """
    # Do not derive IDs from len(collection.get()): that is the number of
    # keys in the returned result mapping, not the number of stored
    # documents. It never changes, so every exchange would reuse the same two
    # IDs, and ChromaDB skips or rejects duplicate IDs. A UUID is unique and
    # avoids fetching the whole collection.
    exchange_id = uuid.uuid4().hex
    collection.add(
        documents=[user_input, bot_response],
        metadatas=[{"role": "user"}, {"role": "bot"}],
        ids=[f"{exchange_id}_user", f"{exchange_id}_bot"],
    )
# Function to retrieve the most recent stored messages as prompt context
def retrieve_context(limit=3):
    """Return up to the last *limit* documents stored in ``collection``.

    This is not a similarity search. It takes the tail of the stored history
    in the order ChromaDB's ``get`` returns it.

    Args:
        limit: Maximum number of documents to return (default 3). Values of
            0 or less yield an empty list.

    Returns:
        A list of document strings (at most *limit* of them).
    """
    if limit <= 0:
        return []
    total = collection.count()
    # Fetch only the tail, and only the documents, instead of the entire
    # collection with its metadata.
    # NOTE(review): "last" assumes get() ordering tracks insertion order;
    # confirm for the installed ChromaDB version.
    history = collection.get(
        offset=max(total - limit, 0),
        limit=limit,
        include=["documents"],
    )
    return history["documents"]
# Function to generate a response using LLM with past context
def get_gemini_response(user_input):
    """Ask Gemini for a reply to *user_input*, primed with recent history.

    Args:
        user_input: The user's latest message.

    Returns:
        The model's reply text. If generating the content or reading
        ``response.text`` raises, returns an apology string that includes
        the error message instead.
    """
    past_context = retrieve_context()
    # Render the history as plain lines. Interpolating the list itself would
    # put its Python repr (brackets, quotes, escaped characters) into the
    # prompt.
    context_text = "\n".join(past_context) if past_context else "(none)"
    full_prompt = f"Previous Chat Context:\n{context_text}\nUser: {user_input}\nBot:"
    try:
        response = model.generate_content(full_prompt)
        return response.text
    except Exception as e:
        # UI boundary: show the failure in the chat rather than crashing
        # the Streamlit script.
        return f"Sorry, I encountered an issue. Error: {str(e)}"
# Streamlit UI
st.title("\U0001F9E0 Mental Health Chatbot")

# Per-session transcript of (role, message) tuples, redrawn on every rerun.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Display chat history
for role, message in st.session_state.chat_history:
    with st.chat_message(role):
        st.write(message)

# User input
user_input = st.chat_input("Ask me anything about mental health...")
if user_input:
    st.chat_message("user").write(user_input)
    # Generate the chatbot response, using recently stored messages as
    # context. (No per-message sentiment analysis here: its result was never
    # used, and running the classifier on every message only added latency.)
    with st.spinner("Thinking..."):
        bot_response = get_gemini_response(user_input)
    st.chat_message("assistant").write(bot_response)
    # Store chat history: in session state for redrawing this session, and
    # in ChromaDB for future prompt context.
    # NOTE(review): the ChromaDB collection is presumably shared by every
    # browser session of this server, unlike st.session_state; confirm that
    # is intended.
    st.session_state.chat_history.append(("user", user_input))
    st.session_state.chat_history.append(("assistant", bot_response))
    store_chat(user_input, bot_response)
# Sidebar: Clear Chat & Instructions
with st.sidebar:
    if st.button("Clear Chat"):
        st.session_state.chat_history = []
        # Only the IDs are needed, so skip fetching documents and metadata.
        stored_ids = collection.get(include=[])["ids"]
        # ChromaDB can reject delete() with an empty ID list, which would
        # crash when the store is already empty, so delete only if needed.
        if stored_ids:
            collection.delete(ids=stored_ids)
        st.success("Chat cleared! Refreshing...")
        st.rerun()
    st.markdown("---")
    st.markdown("*Note: This chatbot is for informational purposes only and should not replace professional mental health advice.*")
# Sentiment Analysis Button
# NOTE(review): assumed to render in the main area rather than inside the
# sidebar; confirm the intended placement.
if st.button("Sentiment Analysis"):
    # Analyze only what the user wrote. Including the assistant's replies
    # would pull the detected emotion toward the bot's tone.
    user_text = " ".join(
        message for role, message in st.session_state.chat_history if role == "user"
    )
    if user_text:
        detected_emotion, emotion_scores = analyze_sentiment(user_text)
        st.subheader("Sentiment Analysis Result")
        st.write(f"Detected Emotion: **{detected_emotion}**")
        st.json(emotion_scores)
    else:
        st.warning("No chat history found for analysis.")