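"""Streamlit chatbot that combines BERTopic topic modeling with
sentence-transformer retrieval over a Hugging Face text dataset.

The app loads a dataset (default: 'Kanakmi/mental-disorders'), fits a BERTopic
model on its text column, and answers each user message by retrieving the most
similar document, mapping its integer 'label' to a mental-health condition, and
appending a supportive response. Per-query metrics (similarity, response time,
token count, topics accessed) are tracked and shown in a Metrics tab.

Run with: streamlit run <this_file>.py
"""
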
import streamlit as st
import requests
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import plotly.graph_objects as go
from datetime import datetime
import json
from collections import deque
from datasets import load_dataset


class BERTopicChatbot:
    def __init__(self, dataset_name, text_column, split="train", max_samples=10000):
        """Initialize the chatbot with a Hugging Face dataset.

        Args:
            dataset_name: name of the dataset on Hugging Face (e.g., 'Kanakmi/mental-disorders')
            text_column: name of the column containing the text data
            split: which split of the dataset to use ('train', 'test', 'validation')
            max_samples: maximum number of samples to use (to manage memory)
        """
        # Initialize BERT sentence transformer
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Add label mapping
        self.label_mapping = {
            0: 'BPD',
            1: 'bipolar',
            2: 'depression',
            3: 'Anxiety',
            4: 'schizophrenia',
            5: 'mentalillness'
        }

        # Add comfort responses
        self.comfort_responses = {
            'BPD': [
                "I understand BPD can be overwhelming. You're not alone in this journey.",
                "Your feelings are valid. BPD is challenging, but there are people who understand.",
                "Taking things one day at a time with BPD is okay. You're showing great strength."
            ],
            'bipolar': [
                "Bipolar disorder can feel like a roller coaster. Remember, stability is possible.",
                "You're so strong for managing bipolar disorder. Take it one day at a time.",
                "Both the highs and lows are temporary. You've gotten through them before."
            ],
            'depression': [
                "Depression is heavy, but you don't have to carry it alone.",
                "Even small steps forward are progress. You're doing better than you think.",
                "This feeling won't last forever. You've made it through difficult times before."
            ],
            'Anxiety': [
                "Your anxiety doesn't define you. You're stronger than your fears.",
                "Remember to breathe. You're safe, and this feeling will pass.",
                "It's okay to take things at your own pace. You're handling this well."
            ],
            'schizophrenia': [
                "You're not your diagnosis. You're a person first, and you matter.",
                "Managing schizophrenia takes incredible strength. You're doing well.",
                "There's support available, and you deserve all the help you need."
            ],
            'mentalillness': [
                "Mental health challenges don't define your worth. You are valuable.",
                "Recovery isn't linear, and that's okay. Every step counts.",
                "You're not alone in this journey. There's a community that understands."
            ]
        }
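
        # Note: label_mapping assumes the loaded dataset's integer 'label' column
        # uses this 0-5 encoding; a dataset with a different labeling scheme would
        # need this dictionary (and comfort_responses) adjusted to match.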
        # Load dataset from Hugging Face
        try:
            dataset = load_dataset(dataset_name, split=split)

            # Subsample if necessary, then convert to a pandas DataFrame
            if len(dataset) > max_samples:
                dataset = dataset.shuffle(seed=42).select(range(max_samples))
            self.df = dataset.to_pandas()

            # Ensure the text column exists
            if text_column not in self.df.columns:
                raise ValueError(f"Column '{text_column}' not found in dataset. Available columns: {list(self.df.columns)}")
            self.documents = self.df[text_column].tolist()

            # Create and train the BERTopic model
            self.topic_model = BERTopic(embedding_model=self.sentence_model)
            self.topics, self.probs = self.topic_model.fit_transform(self.documents)

            # Create document embeddings for similarity search
            self.doc_embeddings = self.sentence_model.encode(self.documents)

            # Initialize metrics storage
            self.metrics_history = {
                'similarities': deque(maxlen=100),
                'response_times': deque(maxlen=100),
                'token_counts': deque(maxlen=100),
                'topics_accessed': {}
            }

            # Store dataset info
            self.dataset_info = {
                'name': dataset_name,
                'split': split,
                'total_documents': len(self.documents),
                'topics_found': len(set(self.topics))
            }
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            raise
    def get_metrics_visualizations(self):
        """Generate visualizations for chatbot metrics"""
        # Similarity trend
        fig_similarity = go.Figure()
        fig_similarity.add_trace(go.Scatter(
            y=list(self.metrics_history['similarities']),
            mode='lines+markers',
            name='Similarity Score'
        ))
        fig_similarity.update_layout(
            title='Response Similarity Trend',
            yaxis_title='Similarity Score',
            xaxis_title='Query Number'
        )

        # Response time trend
        fig_response_time = go.Figure()
        fig_response_time.add_trace(go.Scatter(
            y=list(self.metrics_history['response_times']),
            mode='lines+markers',
            name='Response Time'
        ))
        fig_response_time.update_layout(
            title='Response Time Trend',
            yaxis_title='Time (seconds)',
            xaxis_title='Query Number'
        )

        # Token usage trend
        fig_tokens = go.Figure()
        fig_tokens.add_trace(go.Scatter(
            y=list(self.metrics_history['token_counts']),
            mode='lines+markers',
            name='Token Count'
        ))
        fig_tokens.update_layout(
            title='Token Usage Trend',
            yaxis_title='Number of Tokens',
            xaxis_title='Query Number'
        )

        # Topics accessed pie chart
        labels = list(self.metrics_history['topics_accessed'].keys())
        values = list(self.metrics_history['topics_accessed'].values())
        fig_topics = go.Figure(data=[go.Pie(labels=labels, values=values)])
        fig_topics.update_layout(title='Topics Accessed Distribution')

        # Make all figures responsive
        for fig in [fig_similarity, fig_response_time, fig_tokens, fig_topics]:
            fig.update_layout(
                autosize=True,
                margin=dict(l=20, r=20, t=40, b=20),
                height=300
            )

        return fig_similarity, fig_response_time, fig_tokens, fig_topics
    def get_most_similar_document(self, query, top_k=3):
        """Return the top_k most similar documents, their similarity scores,
        and their row indices in the loaded dataset."""
        # Encode the query
        query_embedding = self.sentence_model.encode([query])[0]
        # Calculate similarities against all document embeddings
        similarities = cosine_similarity([query_embedding], self.doc_embeddings)[0]
        # Get indices of the top k most similar documents (highest similarity first)
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.documents[i] for i in top_indices], similarities[top_indices], top_indices
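
    # Note: cosine similarity between MiniLM sentence embeddings lies in [-1, 1];
    # get_response() below treats a best score under 0.5 as a weak match and
    # falls back to a comfort-only reply instead of quoting a retrieved document.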
    def get_response(self, user_query):
        try:
            start_time = datetime.now()

            # Get most similar documents (with their dataset row indices)
            similar_docs, similarities, doc_indices = self.get_most_similar_document(user_query)

            # Get the label of the most similar document (results are sorted, so index 0)
            label_index = int(self.df['label'].iloc[doc_indices[0]])
            condition = self.label_mapping[label_index]  # Map the integer label to a condition name

            # Pick a comfort response for the detected condition
            comfort_messages = self.comfort_responses[condition]
            comfort_response = np.random.choice(comfort_messages)

            # Determine the query's topic for metrics
            query_topic, _ = self.topic_model.transform([user_query])

            # Combine retrieved information and comfort response
            if max(similarities) < 0.5:
                response = f"I sense you might be dealing with {condition}. {comfort_response}"
            else:
                response = f"{similar_docs[0]}\n\n{comfort_response}"

            # Track metrics
            end_time = datetime.now()
            metrics = {
                'similarity': float(max(similarities)),
                'response_time': (end_time - start_time).total_seconds(),
                'tokens': len(response.split()),
                'topic': str(query_topic[0]),
                'detected_condition': condition
            }

            # Update metrics history
            self.metrics_history['similarities'].append(metrics['similarity'])
            self.metrics_history['response_times'].append(metrics['response_time'])
            self.metrics_history['token_counts'].append(metrics['tokens'])
            topic_id = str(query_topic[0])
            self.metrics_history['topics_accessed'][topic_id] = \
                self.metrics_history['topics_accessed'].get(topic_id, 0) + 1

            return response, metrics
        except Exception as e:
            return f"Error processing query: {str(e)}", {'error': str(e)}
    def get_dataset_info(self):
        """Return information about the loaded dataset and accumulated metrics."""
        try:
            return {
                'dataset_info': self.dataset_info,
                'metrics': {
                    'avg_similarity': np.mean(list(self.metrics_history['similarities'])) if self.metrics_history['similarities'] else 0,
                    'avg_response_time': np.mean(list(self.metrics_history['response_times'])) if self.metrics_history['response_times'] else 0,
                    'total_tokens': sum(self.metrics_history['token_counts']),
                    'topics_accessed': self.metrics_history['topics_accessed']
                }
            }
        except Exception as e:
            return {
                'error': str(e),
                'dataset_info': None,
                'metrics': None
            }


def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000):
    return BERTopicChatbot(dataset_name, text_column, split, max_samples)
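
# A minimal programmatic usage sketch (outside Streamlit), assuming the default
# 'Kanakmi/mental-disorders' dataset with a 'text' column and an integer 'label'
# column covered by BERTopicChatbot.label_mapping:
#
#     bot = initialize_chatbot("Kanakmi/mental-disorders", "text", max_samples=2000)
#     reply, metrics = bot.get_response("I haven't slept in days and I feel hopeless.")
#     print(reply)
#     print(metrics['similarity'], metrics['detected_condition'])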


def main():
    st.title("🤖 AI Assistant - BERTopic")
    st.caption("Come chat with us!")
    # Dataset selection sidebar
    with st.sidebar:
        st.header("Dataset Configuration")
        dataset_name = st.text_input(
            "Hugging Face Dataset Name",
            value="Kanakmi/mental-disorders",
            help="Enter the name of a dataset from Hugging Face (e.g., 'Kanakmi/mental-disorders')"
        )
        text_column = st.text_input(
            "Text Column Name",
            value="text",
            help="Enter the name of the column containing the text data"
        )
        split = st.selectbox(
            "Dataset Split",
            options=["train", "test", "val", "validation"],
            index=0
        )
        max_samples = st.number_input(
            "Maximum Samples",
            min_value=100,
            max_value=100000,
            value=10000,
            step=1000,
            help="Maximum number of samples to load from the dataset"
        )
        if st.button("Load Dataset"):
            with st.spinner("Loading dataset and initializing model..."):
                try:
                    st.session_state.chatbot = initialize_chatbot(
                        dataset_name, text_column, split, max_samples
                    )
                    st.success("Dataset loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading dataset: {str(e)}")

    # Initialize session state variables if they don't exist
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot = None
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Create tabs for chat and metrics
    chat_tab, metrics_tab = st.tabs(["Chat", "Metrics"])

    with chat_tab:
        # Display existing messages
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Only show chat input if chatbot is initialized
        if st.session_state.chatbot is not None:
            if prompt := st.chat_input("Say something..."):
                # Add user message
                st.session_state.messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)

                # Get chatbot response
                response, metrics = st.session_state.chatbot.get_response(prompt)

                # Add assistant response
                with st.chat_message("assistant"):
                    st.markdown(response)
                    with st.expander("Response Metrics"):
                        st.json(metrics)
                st.session_state.messages.append({"role": "assistant", "content": response})
        else:
            st.info("Please load a dataset first to start chatting.")

    with metrics_tab:
        if st.session_state.chatbot is not None:
            try:
                # Get visualizations from session state chatbot
                fig_similarity, fig_response_time, fig_tokens, fig_topics = st.session_state.chatbot.get_metrics_visualizations()

                col1, col2 = st.columns(2)
                with col1:
                    st.plotly_chart(fig_similarity, use_container_width=True)
                    st.plotly_chart(fig_tokens, use_container_width=True)
                with col2:
                    st.plotly_chart(fig_response_time, use_container_width=True)
                    st.plotly_chart(fig_topics, use_container_width=True)

                # Display statistics
                st.subheader("Overall Statistics")
                metrics_history = st.session_state.chatbot.metrics_history
                if len(metrics_history['similarities']) > 0:
                    stats_col1, stats_col2, stats_col3 = st.columns(3)
                    with stats_col1:
                        st.metric("Avg Similarity",
                                  f"{np.mean(list(metrics_history['similarities'])):.3f}")
                    with stats_col2:
                        st.metric("Avg Response Time",
                                  f"{np.mean(list(metrics_history['response_times'])):.3f}s")
                    with stats_col3:
                        st.metric("Total Tokens Used",
                                  sum(metrics_history['token_counts']))
                else:
                    st.info("No chat history available yet. Start a conversation to see metrics.")
            except Exception as e:
                st.error(f"Error displaying metrics: {str(e)}")
        else:
            st.info("Please load a dataset first to view metrics.")


if __name__ == "__main__":
    main()