import streamlit as st
import requests
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import plotly.graph_objects as go
from datetime import datetime
import json
from collections import deque
from datasets import load_dataset
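
# A Streamlit chatbot that embeds a Hugging Face text dataset with Sentence-Transformers,
# fits a BERTopic model over it, and answers each query by retrieving the most similar
# document and pairing it with a supportive message for the detected condition.
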
class BERTopicChatbot:
    def __init__(self, dataset_name, text_column, split="train", max_samples=10000):
        """Initialize the chatbot with a Hugging Face dataset.

        dataset_name: name of the dataset on Hugging Face (e.g., 'vietnam/legal')
        text_column: name of the column containing the text data
        split: which split of the dataset to use ('train', 'test', 'validation')
        max_samples: maximum number of samples to use (to manage memory)
        """
        # Initialize the BERT sentence transformer
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Add label mapping
        self.label_mapping = {
            0: 'BPD',
            1: 'bipolar',
            2: 'depression',
            3: 'Anxiety',
            4: 'schizophrenia',
            5: 'mentalillness'
        }
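        # NOTE: this integer-to-condition mapping is assumed to match the label
        # encoding of the default 'Kanakmi/mental-disorders' dataset; verify or
        # adjust it if you load a different dataset.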
        # Add comfort responses
        self.comfort_responses = {
            'BPD': [
                "I understand BPD can be overwhelming. You're not alone in this journey.",
                "Your feelings are valid. BPD is challenging, but there are people who understand.",
                "Taking things one day at a time with BPD is okay. You're showing great strength."
            ],
            'bipolar': [
                "Bipolar disorder can feel like a roller coaster. Remember, stability is possible.",
                "You're so strong for managing bipolar disorder. Take it one day at a time.",
                "Both the highs and lows are temporary. You've gotten through them before."
            ],
            'depression': [
                "Depression is heavy, but you don't have to carry it alone.",
                "Even small steps forward are progress. You're doing better than you think.",
                "This feeling won't last forever. You've made it through difficult times before."
            ],
            'Anxiety': [
                "Your anxiety doesn't define you. You're stronger than your fears.",
                "Remember to breathe. You're safe, and this feeling will pass.",
                "It's okay to take things at your own pace. You're handling this well."
            ],
            'schizophrenia': [
                "You're not your diagnosis. You're a person first, and you matter.",
                "Managing schizophrenia takes incredible strength. You're doing well.",
                "There's support available, and you deserve all the help you need."
            ],
            'mentalillness': [
                "Mental health challenges don't define your worth. You are valuable.",
                "Recovery isn't linear, and that's okay. Every step counts.",
                "You're not alone in this journey. There's a community that understands."
            ]
        }
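        # NOTE: the constructor loads the dataset, fits BERTopic, and encodes every
        # document up front, so large max_samples values can take a while and use a
        # fair amount of memory.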
        # Load dataset from Hugging Face
        try:
            dataset = load_dataset(dataset_name, split=split)
            # Convert to pandas DataFrame and sample if necessary
            if len(dataset) > max_samples:
                dataset = dataset.shuffle(seed=42).select(range(max_samples))
            self.df = dataset.to_pandas()
            # Ensure the text column exists
            if text_column not in self.df.columns:
                raise ValueError(f"Column '{text_column}' not found in dataset. Available columns: {self.df.columns}")
            self.documents = self.df[text_column].tolist()
            # Create and train the BERTopic model
            self.topic_model = BERTopic(embedding_model=self.sentence_model)
            self.topics, self.probs = self.topic_model.fit_transform(self.documents)
            # Create document embeddings for similarity search
            self.doc_embeddings = self.sentence_model.encode(self.documents)
            # Initialize metrics storage
            self.metrics_history = {
                'similarities': deque(maxlen=100),
                'response_times': deque(maxlen=100),
                'token_counts': deque(maxlen=100),
                'topics_accessed': {}
            }
            # Store dataset info
            self.dataset_info = {
                'name': dataset_name,
                'split': split,
                'total_documents': len(self.documents),
                'topics_found': len(set(self.topics))
            }
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            raise

    def get_metrics_visualizations(self):
        """Generate visualizations for chatbot metrics"""
        # Similarity trend
        fig_similarity = go.Figure()
        fig_similarity.add_trace(go.Scatter(
            y=list(self.metrics_history['similarities']),
            mode='lines+markers',
            name='Similarity Score'
        ))
        fig_similarity.update_layout(
            title='Response Similarity Trend',
            yaxis_title='Similarity Score',
            xaxis_title='Query Number'
        )
        # Response time trend
        fig_response_time = go.Figure()
        fig_response_time.add_trace(go.Scatter(
            y=list(self.metrics_history['response_times']),
            mode='lines+markers',
            name='Response Time'
        ))
        fig_response_time.update_layout(
            title='Response Time Trend',
            yaxis_title='Time (seconds)',
            xaxis_title='Query Number'
        )
        # Token usage trend
        fig_tokens = go.Figure()
        fig_tokens.add_trace(go.Scatter(
            y=list(self.metrics_history['token_counts']),
            mode='lines+markers',
            name='Token Count'
        ))
        fig_tokens.update_layout(
            title='Token Usage Trend',
            yaxis_title='Number of Tokens',
            xaxis_title='Query Number'
        )
        # Topics accessed pie chart
        labels = list(self.metrics_history['topics_accessed'].keys())
        values = list(self.metrics_history['topics_accessed'].values())
        fig_topics = go.Figure(data=[go.Pie(labels=labels, values=values)])
        fig_topics.update_layout(title='Topics Accessed Distribution')
        # Make all figures responsive
        for fig in [fig_similarity, fig_response_time, fig_tokens, fig_topics]:
            fig.update_layout(
                autosize=True,
                margin=dict(l=20, r=20, t=40, b=20),
                height=300
            )
        return fig_similarity, fig_response_time, fig_tokens, fig_topics

    def get_most_similar_document(self, query, top_k=3):
        # Encode the query
        query_embedding = self.sentence_model.encode([query])[0]
        # Calculate cosine similarities against every document embedding
        similarities = cosine_similarity([query_embedding], self.doc_embeddings)[0]
        # Get the indices of the top-k most similar documents (most similar first)
        top_indices = similarities.argsort()[-top_k:][::-1]
        # Also return the original dataset indices so callers can look up
        # metadata (e.g. labels) for the retrieved documents
        return [self.documents[i] for i in top_indices], similarities[top_indices], top_indices

    def get_response(self, user_query):
        try:
            start_time = datetime.now()
            # Get the most similar documents and their dataset indices
            similar_docs, similarities, top_indices = self.get_most_similar_document(user_query)
            # Look up the label of the most similar document in the original DataFrame
            best_doc_index = top_indices[similarities.argmax()]
            label_index = int(self.df['label'].iloc[best_doc_index])  # Convert to int
            condition = self.label_mapping[label_index]  # Map the integer label to a condition name
            # Get a comfort response for the detected condition
            comfort_messages = self.comfort_responses[condition]
            comfort_response = np.random.choice(comfort_messages)
            # Calculate the query topic for metrics
            query_topic, _ = self.topic_model.transform([user_query])
            # Combine retrieved information and the comfort response
            if max(similarities) < 0.5:
                response = f"I sense you might be dealing with {condition}. {comfort_response}"
            else:
                response = f"{similar_docs[0]}\n\n{comfort_response}"
            # Track metrics
            end_time = datetime.now()
            metrics = {
                'similarity': float(max(similarities)),
                'response_time': (end_time - start_time).total_seconds(),
                'tokens': len(response.split()),
                'topic': str(query_topic[0]),
                'detected_condition': condition
            }
            # Update metrics history
            self.metrics_history['similarities'].append(metrics['similarity'])
            self.metrics_history['response_times'].append(metrics['response_time'])
            self.metrics_history['token_counts'].append(metrics['tokens'])
            topic_id = str(query_topic[0])
            self.metrics_history['topics_accessed'][topic_id] = \
                self.metrics_history['topics_accessed'].get(topic_id, 0) + 1
            return response, metrics
        except Exception as e:
            return f"Error processing query: {str(e)}", {'error': str(e)}

    def get_dataset_info(self):
        """Return information about the loaded dataset and metrics"""
        try:
            return {
                'dataset_info': self.dataset_info,
                'metrics': {
                    'avg_similarity': np.mean(list(self.metrics_history['similarities'])) if self.metrics_history['similarities'] else 0,
                    'avg_response_time': np.mean(list(self.metrics_history['response_times'])) if self.metrics_history['response_times'] else 0,
                    'total_tokens': sum(self.metrics_history['token_counts']),
                    'topics_accessed': self.metrics_history['topics_accessed']
                }
            }
        except Exception as e:
            return {
                'error': str(e),
                'dataset_info': None,
                'metrics': None
            }
| def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000): | |
| return BERTopicChatbot(dataset_name, text_column, split, max_samples) | |
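
# Example usage outside Streamlit (an illustrative sketch; assumes the default
# 'Kanakmi/mental-disorders' dataset, which provides 'text' and 'label' columns):
#   bot = initialize_chatbot("Kanakmi/mental-disorders", "text", split="train", max_samples=5000)
#   reply, metrics = bot.get_response("I have been feeling anxious and cannot sleep")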


def main():
    st.title("🤖 AI Assistant - BERTopic")
    st.caption("Come chat with us!")
    # Dataset selection sidebar
    with st.sidebar:
        st.header("Dataset Configuration")
        dataset_name = st.text_input(
            "Hugging Face Dataset Name",
            value="Kanakmi/mental-disorders",
            help="Enter the name of a dataset from Hugging Face (e.g., 'Kanakmi/mental-disorders')"
        )
        text_column = st.text_input(
            "Text Column Name",
            value="text",
            help="Enter the name of the column containing the text data"
        )
        split = st.selectbox(
            "Dataset Split",
            options=["train", "test", "val", "validation"],
            index=0
        )
        max_samples = st.number_input(
            "Maximum Samples",
            min_value=100,
            max_value=100000,
            value=10000,
            step=1000,
            help="Maximum number of samples to load from the dataset"
        )
        if st.button("Load Dataset"):
            with st.spinner("Loading dataset and initializing model..."):
                try:
                    st.session_state.chatbot = initialize_chatbot(
                        dataset_name, text_column, split, max_samples
                    )
                    st.success("Dataset loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading dataset: {str(e)}")
    # Initialize session state variables if they don't exist
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot = None
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    # Create tabs for chat and metrics
    chat_tab, metrics_tab = st.tabs(["Chat", "Metrics"])
    with chat_tab:
        # Display existing messages
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
        # Only show the chat input if the chatbot is initialized
        if st.session_state.chatbot is not None:
            if prompt := st.chat_input("Say something..."):
                # Add the user message
                st.session_state.messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)
                # Get the chatbot response
                response, metrics = st.session_state.chatbot.get_response(prompt)
                # Add the assistant response
                with st.chat_message("assistant"):
                    st.markdown(response)
                    with st.expander("Response Metrics"):
                        st.json(metrics)
                st.session_state.messages.append({"role": "assistant", "content": response})
        else:
            st.info("Please load a dataset first to start chatting.")
    with metrics_tab:
        if st.session_state.chatbot is not None:
            try:
                # Get visualizations from the session state chatbot
                fig_similarity, fig_response_time, fig_tokens, fig_topics = st.session_state.chatbot.get_metrics_visualizations()
                col1, col2 = st.columns(2)
                with col1:
                    st.plotly_chart(fig_similarity, use_container_width=True)
                    st.plotly_chart(fig_tokens, use_container_width=True)
                with col2:
                    st.plotly_chart(fig_response_time, use_container_width=True)
                    st.plotly_chart(fig_topics, use_container_width=True)
                # Display statistics
                st.subheader("Overall Statistics")
                metrics_history = st.session_state.chatbot.metrics_history
                if len(metrics_history['similarities']) > 0:
                    stats_col1, stats_col2, stats_col3 = st.columns(3)
                    with stats_col1:
                        st.metric("Avg Similarity",
                                  f"{np.mean(list(metrics_history['similarities'])):.3f}")
                    with stats_col2:
                        st.metric("Avg Response Time",
                                  f"{np.mean(list(metrics_history['response_times'])):.3f}s")
                    with stats_col3:
                        st.metric("Total Tokens Used",
                                  sum(metrics_history['token_counts']))
                else:
                    st.info("No chat history available yet. Start a conversation to see metrics.")
            except Exception as e:
                st.error(f"Error displaying metrics: {str(e)}")
        else:
            st.info("Please load a dataset first to view metrics.")
| if __name__ == "__main__": | |
| main() |
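
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py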