# TalkToData — Streamlit app (source recovered from a web-page scrape; layout reconstructed).
# Standard library
import os
import re
import sqlite3
import time
from datetime import datetime

# Third-party
import pandas as pd
import streamlit as st

# Project-local
from agents.dataframe_agent import get_dataframe_agent
from agents.sql_agent.agent import SQLAgent
from agents.tools import PlotSQLTool
from utils.consts import DB_PATH

# Bare file name of the SQLite database; used throughout the UI as the
# label for the database data source.
db_name = os.path.basename(DB_PATH)

st.set_page_config(page_title="🔍 TalkToData", layout="wide", initial_sidebar_state="collapsed")
# Markdown title intentionally omitted here to avoid duplicate rendering.
# Sidebar: static "About" panel.
with st.sidebar:
    st.header("ℹ️ About", anchor=None)
    st.markdown("""
**TalkToData** v0.1.0
Your personal AI Data Analyst.
""", unsafe_allow_html=True)
# Initialize chat history (legacy global history; per-source histories live
# in st.session_state['chat_histories'] further below).
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Initialize SQL agent. NOTE(review): this constructs a fresh SQLAgent on
# every Streamlit rerun; if construction is expensive, consider caching it
# in st.session_state — confirm SQLAgent is cheap/stateless before changing.
# agent = get_sql_agent()
agent = SQLAgent()

# Initial state dict streamed through the SQL agent graph for each question.
state = {
    "question": None,          # natural-language question from the user
    "db_info": {
        "tables": [],
        "columns": {},
        "schema": None
    },
    "sql_query": None,         # SQL generated by the agent
    "sql_result": None,        # result of executing the SQL
    "error": None,
    "step": None,
    "answer": None             # final natural-language answer
}
# --- Upload Screen State ---
# Tracks whether the user has passed the landing page.
if 'files_uploaded' not in st.session_state:
    st.session_state['files_uploaded'] = False
# TEMP: Bypass landing page (forces the main app view on every rerun).
st.session_state['files_uploaded'] = True
if not st.session_state['files_uploaded']:
    # ---- Landing / welcome page ----
    # CSS to center and enlarge only the welcome start button
    st.markdown("""
<style>
.welcome .stButton { display: flex; justify-content: center; }
.welcome .stButton button { font-size:2.5rem !important; padding:1.25rem 2rem !important; }
</style>
""", unsafe_allow_html=True)
    # Wrap welcome content to scope styling
    st.markdown("<div class='welcome' style='max-width:600px;margin:auto;text-align:center;'>", unsafe_allow_html=True)
    # Title and subtitle
    st.markdown("""
<h1 style='text-align:center; margin-bottom:0;'>🔍 TalkToData</h1>
<h3 style='text-align:center; color:gray;'>Your Personal AI Data Analyst that instantly answers your data questions with clear insights and elegant visualizations.</h3>
""", unsafe_allow_html=True)
    # Standalone welcome start button
    if st.button("🚀 Explore now", key="start"):
        st.session_state['files_uploaded'] = True
        # st.experimental_rerun() was removed in modern Streamlit
        # (this file already uses st.html, which requires >= 1.34).
        st.rerun()
    # Close welcome wrapper
    st.markdown("</div>", unsafe_allow_html=True)
    st.divider()
    # SaaS-style Features section
    st.markdown("## Features")
    feat_cols = st.columns(3)
    feat_cols[0].markdown("### 🗣 Natural-Language Queries\nAsk your data without SQL knowledge.")
    feat_cols[1].markdown("### 📊 Instant Visualizations\nGet charts from one command.")
    feat_cols[2].markdown("### 🔒 Secure & Local\nYour data stays on your machine.")
    st.divider()
    # How It Works section
    st.markdown("## How It Works")
    step_cols = st.columns(3)
    step_cols[0].markdown("#### 1️⃣ Upload\nUpload .db or CSV files.")
    step_cols[1].markdown("#### 2️⃣ Chat\nInteract in natural language.")
    step_cols[2].markdown("#### 3️⃣ Visualize\nSee results as tables or charts.")
    st.divider()
    # Use Cases
    st.markdown("## Use Cases")
    st.markdown("- \"Show me top 5 products by sales\" → Chart")
    st.markdown("- \"List customers from 2020\" → Table")
    st.divider()
    # Testimonials
    st.markdown("## Testimonials")
    testi_cols = st.columns(2)
    testi_cols[0].markdown("> \"TalkToData transformed our data workflow!\" \n— Jane Doe, Data Analyst")
    testi_cols[1].markdown("> \"The AI assistant is incredibly smart and fast.\" \n— John Smith, Product Manager")
    st.divider()
    # Footer
    st.markdown("2025 TalkToData. All rights reserved.")
    st.markdown("<p style='text-align: center; color: gray;'>TalkToData v0.1.0 - Copyright 2025 by <a href='https://github.com/phamdinhkhanh'>Khanh Pham</a></p>", unsafe_allow_html=True)
    st.html(
        "<p><span style='text-decoration: line-through double red;'>Oops</span>!</p>"
    )
    st.divider()
| else: | |
| # App title and return button | |
| # st.title("🔍 TalkToData") | |
| st.markdown("### TalkToData") | |
| # TEMP: Commented out back-to-home | |
| # if st.button('⬅️ Back to Home', key='back_to_upload'): | |
| # st.session_state['files_uploaded'] = False | |
| # # Xóa dữ liệu cũ | |
| # if 'uploaded_csvs' in st.session_state: | |
| # st.session_state['uploaded_csvs'] = [] | |
| # st.experimental_rerun() | |
| # Layout: Data source selector, main content, and chat | |
| data_col, left_col, right_col = st.columns([1.5, 3, 2]) | |
| # Data source selection | |
| with data_col: | |
| # st.subheader("Data Sources") | |
| # Upload data | |
| with st.expander("**Upload Data**", expanded=True): | |
| st.file_uploader('Select SQLite (.db), CSV or Excel (.xlsx) files', | |
| type=['db', 'csv', 'xlsx'], | |
| accept_multiple_files=True, | |
| key='upload_any_col', | |
| label_visibility="collapsed") | |
| gsheet_url = st.text_input('Enter Google Sheets URL (optional)', '', key='gsheet_url') | |
| upload_status = [] | |
| has_db = False | |
| has_csv = False | |
| # Retrieve uploaded files list safely | |
| uploaded_files = st.session_state.get('upload_any_col', []) | |
| # Process Google Sheets if URL provided | |
| url = st.session_state.get('gsheet_url', '').strip() | |
| if url: | |
| try: | |
| csv_url = url.replace('/edit#gid=', '/export?format=csv&gid=') | |
| df_gs = pd.read_csv(csv_url) | |
| if 'uploaded_csvs' not in st.session_state: | |
| st.session_state['uploaded_csvs'] = [] | |
| st.session_state['uploaded_csvs'].append({'name': 'GoogleSheets', 'df': df_gs}) | |
| upload_status.append('✅ Google Sheets loaded') | |
| has_csv = True | |
| except Exception as e: | |
| upload_status.append(f'❌ Google Sheets error: {e}') | |
| # Process files | |
| for f in uploaded_files: | |
| if f.name.lower().endswith('.db'): | |
| try: | |
| with open(DB_PATH, "wb") as dbf: | |
| dbf.write(f.read()) | |
| upload_status.append(f"✅ Database: {f.name}") | |
| has_db = True | |
| except Exception as e: | |
| upload_status.append(f"❌ Database error: {e}") | |
| # Process CSV and Excel | |
| name = f.name.lower() | |
| if name.endswith('.csv') or name.endswith('.xlsx'): | |
| try: | |
| if name.endswith('.xlsx'): | |
| # Process each sheet in Excel | |
| f.seek(0) | |
| xls = pd.ExcelFile(f) | |
| sheets = st.multiselect(f"Select sheet(s) from {f.name}", xls.sheet_names, default=xls.sheet_names) | |
| for sheet in sheets: | |
| # Read raw to detect header rows | |
| raw = xls.parse(sheet, header=None) | |
| nn = raw.notnull().sum(axis=1) | |
| hdr = [i for i, cnt in enumerate(nn) if cnt > 1] | |
| if len(hdr) >= 2: | |
| header = hdr[:2] | |
| elif len(hdr) == 1: | |
| header = [hdr[0]] | |
| else: | |
| header = [0] | |
| df_sheet = xls.parse(sheet, header=header) | |
| # Flatten MultiIndex if needed | |
| if isinstance(df_sheet.columns, pd.MultiIndex): | |
| df_sheet.columns = [" ".join([str(x) for x in col if pd.notna(x)]).strip() for col in df_sheet.columns] | |
| # Store with sheet label | |
| sheet_key = f"{f.name}:{sheet}" | |
| if 'uploaded_csvs' not in st.session_state: | |
| st.session_state['uploaded_csvs'] = [] | |
| st.session_state['uploaded_csvs'].append({'name': sheet_key, 'df': df_sheet}) | |
| upload_status.append(f"✅ Excel: {sheet_key}") | |
| else: | |
| temp_df = pd.read_csv(f) | |
| if 'uploaded_csvs' not in st.session_state: | |
| st.session_state['uploaded_csvs'] = [] | |
| # Check existing and update | |
| csv_exists = False | |
| for i, csv in enumerate(st.session_state['uploaded_csvs']): | |
| if csv['name'] == f.name: | |
| st.session_state['uploaded_csvs'][i]['df'] = temp_df | |
| csv_exists = True | |
| break | |
| if not csv_exists: | |
| st.session_state['uploaded_csvs'].append({'name': f.name, 'df': temp_df}) | |
| upload_status.append(f"✅ CSV/Excel: {f.name}") | |
| has_csv = True | |
| except Exception as e: | |
| upload_status.append(f"❌ CSV/Excel error: {e}") | |
| # Hiển thị trạng thái upload | |
| if upload_status: | |
| for status in upload_status: | |
| st.write(status) | |
| # After upload, select data sources | |
| ds = [] | |
| if os.path.exists(DB_PATH) and os.path.getsize(DB_PATH) > 0: | |
| ds.append(db_name) | |
| if 'uploaded_csvs' in st.session_state: | |
| ds += [csv['name'] for csv in st.session_state['uploaded_csvs']] | |
| if ds: | |
| # Initialize selected_sources session state to default to db_name | |
| if 'selected_sources' not in st.session_state: | |
| st.session_state['selected_sources'] = [db_name] if db_name in ds else [] | |
| selected_sources = st.multiselect( | |
| "**Select sources**", options=ds, | |
| key='selected_sources' | |
| ) | |
| else: | |
| st.info("Upload a database or CSV/Excel file to select a data source.") | |
| with left_col: | |
| # Data Preview: filter sources by user selection | |
| selected = st.session_state.get('selected_sources', []) | |
| preview_db = os.path.exists(DB_PATH) and db_name in selected | |
| # Filter CSV/Excel previews | |
| preview_csvs = [csv for csv in st.session_state.get('uploaded_csvs', []) if csv['name'] in selected] | |
| if preview_db or preview_csvs: | |
| # Display previews | |
| with st.container(height=415): | |
| st.markdown("**Data Preview**") | |
| # Build tab labels | |
| tab_labels = [] | |
| if preview_db: | |
| tab_labels.append(db_name) | |
| for c in preview_csvs: | |
| tab_labels.append(c['name']) | |
| tabs = st.tabs(tab_labels) | |
| idx = 0 | |
| # Database preview | |
| if preview_db: | |
| with tabs[idx]: | |
| conn = sqlite3.connect(DB_PATH) | |
| tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall() | |
| if tables: | |
| t_tabs = st.tabs([t[0] for t in tables]) | |
| for t, tab in zip(tables, t_tabs): | |
| with tab: | |
| st.table(pd.read_sql_query(f"SELECT * FROM {t[0]}", conn)) | |
| else: | |
| st.info("No tables found.") | |
| conn.close() | |
| idx += 1 | |
| # CSV/Excel previews | |
| for c in preview_csvs: | |
| with tabs[idx]: | |
| st.table(c['df']) | |
| idx += 1 | |
| # --- Data Exploration Section (Always Visible) --- | |
| with st.container(height=225): | |
| # Data Exploration: only support Database source | |
| selected = st.session_state.get('selected_sources', []) | |
| if db_name not in selected: | |
| st.warning(f"⚠️ Data Exploration only supports SQL queries on database .db files. Please select at least a database to continue.") | |
| else: | |
| # st.subheader("Data Exploration") | |
| sql_explore = st.text_area( | |
| "Enter SQL query to explore:", | |
| value=st.session_state.get('explore_sql', ''), | |
| height=100, | |
| key='explore_sql' | |
| ) | |
| if st.button("Run Query", key="explore_run"): | |
| try: | |
| df_explore = pd.read_sql_query(sql_explore, sqlite3.connect(DB_PATH)) | |
| st.session_state['explore_result'] = df_explore | |
| # Record exploration history | |
| if 'explore_history' not in st.session_state: | |
| st.session_state['explore_history'] = [] | |
| # User query | |
| st.session_state['explore_history'].append({ | |
| 'source': 'explore', 'role': 'user', 'content': sql_explore, 'timestamp': datetime.now() | |
| }) | |
| # Assistant result as CSV | |
| res_str = df_explore.to_csv(index=False) | |
| st.session_state['explore_history'].append({ | |
| 'source': 'explore', 'role': 'assistant', 'content': res_str, 'timestamp': datetime.now() | |
| }) | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| # Wrap tabs in scrollable container | |
| with st.container(height=300): | |
| # st.markdown("<div style='height:300px; overflow:auto'>", unsafe_allow_html=True) | |
| tabs = st.tabs(["Results", "History"]) | |
| # Results tab: show explore_result only | |
| with tabs[0]: | |
| if 'explore_result' in st.session_state: | |
| # st.subheader("Results") | |
| st.table(st.session_state['explore_result']) | |
| else: | |
| st.write("No results yet.") | |
| # History tab: Query history | |
| with tabs[1]: | |
| # st.subheader("History") | |
| # Build paired history entries | |
| combined = [] | |
| # Exploration history pairs | |
| explore_hist = st.session_state.get('explore_history', []) | |
| for i in range(0, len(explore_hist), 2): | |
| u = explore_hist[i] if i < len(explore_hist) else {} | |
| a = explore_hist[i+1] if i+1 < len(explore_hist) else {} | |
| combined.append({ | |
| 'source': db_name, | |
| 'query_type': 'sql', | |
| 'query': u.get('content'), | |
| 'result': a.get('content'), | |
| 'timestamp': u.get('timestamp') | |
| }) | |
| # Chat history pairs for all sources | |
| for source, chat_hist in st.session_state.get('chat_histories', {}).items(): | |
| for idx in range(len(chat_hist)): | |
| if chat_hist[idx].get('role') == 'user': | |
| q = chat_hist[idx].get('content') | |
| r = chat_hist[idx+1].get('content') if idx+1 < len(chat_hist) else None | |
| combined.append({ | |
| 'source': source, | |
| 'query_type': 'chat', | |
| 'query': q, | |
| 'result': r, | |
| 'timestamp': chat_hist[idx].get('timestamp') | |
| }) | |
| if combined: | |
| df_history = pd.DataFrame(combined) | |
| # ensure timestamp column is datetime | |
| if not pd.api.types.is_datetime64_any_dtype(df_history['timestamp']): | |
| df_history['timestamp'] = pd.to_datetime(df_history['timestamp']) | |
| # sort latest first | |
| df_history = df_history.sort_values('timestamp', ascending=False) | |
| st.table(df_history) | |
| else: | |
| st.write("No history yet.") | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| with right_col: | |
| # Use selected_sources from left data selector | |
| data_sources = st.session_state.get('selected_sources', []) | |
| csv_files = st.session_state.get('uploaded_csvs', []) | |
| selected_source = data_sources[0] if data_sources else None | |
| # Chat history per source (only if a source is selected) | |
| if 'chat_histories' not in st.session_state: | |
| st.session_state['chat_histories'] = {} | |
| # Initialize past conversations container | |
| if 'all_conversations' not in st.session_state: | |
| st.session_state['all_conversations'] = {} | |
| # Only proceed with chat if a data source is selected | |
| if selected_source is not None: | |
| if selected_source not in st.session_state['chat_histories']: | |
| st.session_state['chat_histories'][selected_source] = [] | |
| if selected_source not in st.session_state['all_conversations']: | |
| st.session_state['all_conversations'][selected_source] = [] | |
| chat_history = st.session_state['chat_histories'][selected_source] | |
| # Only show chat interface if a data source is selected | |
| if selected_source is not None: | |
| container = st.container(height=700, border=True) | |
| # Align New Conversation button top-right | |
| with container: | |
| cols = st.columns([2, 1]) | |
| with cols[0]: | |
| st.markdown("**Ask TalkToData**") | |
| if cols[1].button("New Chat", key=f"new_conv_{selected_source}"): | |
| if chat_history: | |
| conv = chat_history.copy() | |
| ts = conv[0].get('timestamp', datetime.now()) | |
| st.session_state['all_conversations'][selected_source].append({'messages':conv, 'timestamp':ts}) | |
| st.session_state['chat_histories'][selected_source] = [] | |
| st.experimental_rerun() | |
| # Display chat messages | |
| chat_history = st.session_state['chat_histories'][selected_source] | |
| # Welcome message for new chat | |
| if not chat_history: | |
| container.chat_message("assistant").write("👋 Hello! Welcome to TalkToData. Ask any question about your data to get started.") | |
| for turn in chat_history: | |
| role = turn.get('role', '') | |
| content = turn.get('content', '') | |
| if role == 'user': | |
| container.chat_message("user").write(content) | |
| else: | |
| container.chat_message("assistant").write(content) | |
| # Chat input | |
| user_input = st.chat_input(f"Ask a question about {selected_source}...") | |
| else: | |
| # Placeholder to maintain layout | |
| st.container(height=700, border=True) | |
| user_input = None | |
| if user_input: | |
| chat_history.append({"role": "user", "content": user_input, "timestamp": datetime.now()}) | |
| with container.chat_message("user"): | |
| st.write(user_input) | |
| # Answer logic | |
| with container.chat_message("assistant"): | |
| with st.spinner("Thinking..."): | |
| if selected_source == db_name: | |
| # Handle /sql and /plot commands | |
| if user_input.strip().lower().startswith('/sql'): | |
| sql = user_input[len('/sql'):].strip() | |
| try: | |
| df = pd.read_sql_query(sql, sqlite3.connect(DB_PATH)) | |
| st.write(f"```sql\n{sql}\n```") | |
| st.table(df) | |
| chat_history.append({"role": "assistant", "content": f"```sql\n{sql}\n```", "timestamp": datetime.now()}) | |
| except Exception as e: | |
| err = f"SQL Error: {e}" | |
| st.error(err) | |
| chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()}) | |
| elif user_input.strip().lower().startswith('/plot'): | |
| sql = user_input[len('/plot'):].strip() | |
| try: | |
| tool = PlotSQLTool() | |
| md = tool._run(sql) | |
| st.markdown(md) | |
| m = re.search(r'!\[.*\]\((.*?)\)', md) | |
| if m: | |
| st.image(m.group(1)) | |
| chat_history.append({"role": "assistant", "content": md, "timestamp": datetime.now()}) | |
| except Exception as e: | |
| err = f"Plot Error: {e}" | |
| st.error(err) | |
| chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()}) | |
| else: | |
| # Use SQL agent as before | |
| state['question'] = user_input | |
| try: | |
| for step in agent.graph.stream(state, stream_mode="updates"): | |
| step_name, step_details = next(iter(step.items())) | |
| if step_name == 'generate_sql': | |
| with st.expander("SQL Generated", expanded=False): | |
| st.markdown(f"```sql\n{step_details.get('sql_query', '')}\n```") | |
| elif step_name == 'execute_sql': | |
| with st.expander("SQL Result", expanded=False): | |
| st.table(step_details.get('sql_result', pd.DataFrame())) | |
| elif step_name == 'generate_answer': | |
| st.write(step_details.get('answer', '')) | |
| chat_history.append({"role": "assistant", "content": step_details.get('answer', ''), "timestamp": datetime.now()}) | |
| elif step_name == 'render_visualization': | |
| try: | |
| visualization_output = step_details.get('visualization_output') | |
| if visualization_output and os.path.exists(visualization_output): | |
| st.image(visualization_output) | |
| else: | |
| print("No visualization was generated for this query.") | |
| except Exception as e: | |
| print(f"Could not display visualization: {str(e)}") | |
| except Exception as e: | |
| err = f"SQL Agent Error: {e}" | |
| print(err) | |
| chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()}) | |
| else: | |
| # Use DataFrame agent for selected CSV | |
| csv_file = next((csv for csv in csv_files if csv['name'] == selected_source), None) | |
| if csv_file: | |
| if 'csv_agents' not in st.session_state: | |
| st.session_state['csv_agents'] = {} | |
| if selected_source not in st.session_state['csv_agents']: | |
| st.session_state['csv_agents'][selected_source] = get_dataframe_agent(csv_file['df']) | |
| agent = st.session_state['csv_agents'][selected_source] | |
| try: | |
| response = agent.invoke(user_input) | |
| answer = response["output"] if isinstance(response, dict) and "output" in response else str(response) | |
| except Exception as e: | |
| answer = f"CSV Agent Error: {e}" | |
| st.write(answer) | |
| chat_history.append({"role": "assistant", "content": answer, "timestamp": datetime.now()}) | |
| # Refresh to update History immediately | |
| # st.experimental_rerun() | |
| # Past Conversations Panel | |
| with st.container(height=200): | |
| st.markdown("**Recent Conversations**") | |
| # Flatten and sort conversations by most recent first | |
| entries = [] | |
| for source, convs in st.session_state.get('all_conversations', {}).items(): | |
| for conv in convs: | |
| entries.append((source, conv)) | |
| entries = sorted(entries, key=lambda x: x[1]['timestamp'], reverse=True) | |
| for source, conv in entries: | |
| label = conv['timestamp'].strftime("%Y-%m-%d %H:%M:%S") | |
| with st.expander(f"{source} - {label}", expanded=False): | |
| for msg in conv['messages']: | |
| if msg.get('role') == 'user': | |
| st.chat_message('user').write(msg.get('content')) | |
| else: | |
| st.chat_message('assistant').write(msg.get('content')) | |