Spaces:
Sleeping
Sleeping
import hashlib
import hmac
import sqlite3
from contextlib import closing

import pandas as pd
import requests
import streamlit as st
# Initialize SQLite database
def init_db(*, db_path: str = 'users.db') -> None:
    """Create the ``users`` table if it does not already exist.

    The table has a ``username`` TEXT primary key and a ``password`` TEXT
    column holding the stored password hash.

    Args:
        db_path: SQLite database file to initialise. Defaults to
            ``users.db`` (relative to the working directory).
    """
    # ``closing`` guarantees the connection is closed even if the statement
    # raises; the inner ``with conn`` commits on success / rolls back on
    # error (a sqlite3 connection's own context manager does not close it).
    with closing(sqlite3.connect(db_path)) as conn:
        with conn:
            conn.execute(
                'CREATE TABLE IF NOT EXISTS users '
                '(username TEXT PRIMARY KEY, password TEXT)'
            )
# Hash password
def hash_password(password: str) -> str:
    """Return the lowercase hex SHA-256 digest of *password* encoded as UTF-8.

    NOTE(review): this is a single, unsalted SHA-256 round, so equal
    passwords produce equal stored values and guesses are cheap to test.
    Consider a salted key-derivation function (``hashlib.pbkdf2_hmac`` or
    ``hashlib.scrypt``); switching needs a migration path for the hashes
    already written to ``users.db``.
    """
    digest = hashlib.sha256(password.encode('utf-8'))
    return digest.hexdigest()
# User authentication
def authenticate_user(username: str, password: str, *, db_path: str = 'users.db') -> bool:
    """Check a username/password pair against the ``users`` table.

    Args:
        username: Account name to look up.
        password: Plain-text password entered by the user.
        db_path: SQLite database file holding the ``users`` table.

    Returns:
        ``True`` only if the user exists and ``hash_password(password)``
        equals the stored hash; ``False`` otherwise (unknown user, NULL
        stored password, or mismatch).
    """
    # ``closing`` releases the connection even if the query raises.
    with closing(sqlite3.connect(db_path)) as conn:
        row = conn.execute(
            'SELECT password FROM users WHERE username=?', (username,)
        ).fetchone()

    if row is None or not isinstance(row[0], str):
        return False
    # Constant-time comparison so response timing does not reveal how much
    # of the stored hash matched. Comparing bytes avoids compare_digest's
    # TypeError on non-ASCII str input.
    return hmac.compare_digest(
        row[0].encode('utf-8'), hash_password(password).encode('utf-8')
    )
# User registration
def register_user(username: str, password: str, *, db_path: str = 'users.db') -> bool:
    """Insert a new user whose password is stored as ``hash_password(password)``.

    Args:
        username: Desired account name (primary key of ``users``).
        password: Plain-text password; only its hash is stored.
        db_path: SQLite database file holding the ``users`` table.

    Returns:
        ``True`` if the row was inserted, ``False`` if the insert violated a
        constraint (e.g. *username* already exists).
    """
    try:
        # ``closing`` closes the connection on every path, including the
        # IntegrityError path; ``with conn`` commits or rolls back.
        with closing(sqlite3.connect(db_path)) as conn:
            with conn:
                conn.execute(
                    'INSERT INTO users (username, password) VALUES (?, ?)',
                    (username, hash_password(password)),
                )
    except sqlite3.IntegrityError:
        return False
    return True
# Initialize session state
def init_session_state():
    """Seed ``st.session_state`` with default auth/navigation values.

    Only keys that are missing are set; values already present (e.g. from an
    earlier run of the script) are left unchanged.
    """
    defaults = {
        'logged_in': False,
        'current_page': 'login',
        'username': None,
        'db_connected': False,
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default
# Login/Signup page
def login_page():
    """Render the Login / Sign Up tabs and process their form submissions.

    * Login: on valid credentials, sets ``logged_in`` and ``username`` in
      session state, switches ``current_page`` to ``'db_connection'`` and
      reruns the script; otherwise shows an error.
    * Sign Up: rejects a blank username or empty password, checks that both
      password fields match, then calls ``register_user``.
    """
    # NOTE(review): the emoji in this title look mis-encoded (mojibake);
    # confirm the intended text before changing the literal.
    st.set_page_config(page_title="Talk2SQLπ¨πΌβπ»π’", layout="wide")
    st.header('Talk2SQLπ¨πΌβπ»π’')
    st.title('Login / Sign Up')
    login_tab, signup_tab = st.tabs(['Login', 'Sign Up'])

    with login_tab:
        with st.form('login_form'):
            username = st.text_input('Username')
            password = st.text_input('Password', type='password')
            submit = st.form_submit_button('Login')
        if submit:
            if authenticate_user(username, password):
                st.session_state.logged_in = True
                st.session_state.username = username
                st.session_state.current_page = 'db_connection'
                st.rerun()
            else:
                st.error('Invalid username or password')

    with signup_tab:
        with st.form('signup_form'):
            new_username = st.text_input('Username')
            new_password = st.text_input('Password', type='password')
            confirm_password = st.text_input('Confirm Password', type='password')
            submit = st.form_submit_button('Sign Up')
        if submit:
            # Without this guard a blank username and/or empty password
            # could be registered and later used to log in.
            if not new_username.strip() or not new_password:
                st.error('Username and password are required')
            elif new_password != confirm_password:
                st.error('Passwords do not match')
            elif register_user(new_username, new_password):
                st.success('Registration successful! Please login.')
            else:
                st.error('Username already exists')
# Database connection page
def db_connection_page():
    """Render the database-connection page.

    The sidebar shows sample data, sample questions and a logout button. The
    main area collects a connection string and POSTs it as
    ``{'connection_string': ...}`` to the backend's
    ``/api/v1/setup-connection`` endpoint. On HTTP 200 the app sets
    ``db_connected`` and switches ``current_page`` to ``'chat'``; any other
    status or a request error is shown to the user.
    """
    st.set_page_config(page_title="Talk2SQLπ¨πΌβπ»π’", layout="wide")
    st.header('Talk2SQLπ¨πΌβπ»π’')
    st.title('Database Connection')

    _render_sample_sidebar()

    # Main content
    db_options = ["MySQL", "PostgreSQL"]
    db_type = st.selectbox("Select Database Type", db_options)
    placeholder_text = ""
    if db_type == "PostgreSQL":
        placeholder_text = "postgresql://user:password@host:port/database"
    elif db_type == "MySQL":
        placeholder_text = "mysql+pymysql://user:password@host:port/database"

    with st.form('connection_form'):
        connection_string = st.text_input(
            'Connection String', placeholder=placeholder_text, disabled=not db_type
        )
        submit = st.form_submit_button('Connect')

    if not submit:
        return
    connection_string = connection_string.strip()
    if not connection_string:
        st.warning('Please enter a connection string.')
        return

    try:
        # Without a timeout, requests can block indefinitely and freeze the
        # page. (connect, read) seconds.
        # NOTE(review): the 60 s read budget is a guess; tune to the backend.
        response = requests.post(
            'http://localhost:8000/api/v1/setup-connection',
            json={'connection_string': connection_string},
            timeout=(5, 60),
        )
        if response.status_code == 200:
            st.success('Database connected successfully!')
            st.session_state.db_connected = True
            st.session_state.current_page = 'chat'
            st.rerun()
        else:
            st.error(f'Connection failed: {response.text}')
    except requests.RequestException as e:
        st.error(f'Error connecting to backend: {e}')


def _render_sample_sidebar():
    """Render the sidebar: sample connection string, table, questions, logout."""
    with st.sidebar:
        st.header("Sample Data")

        st.subheader("Sample Connection String")
        # NOTE(review): this sample embeds a real-looking host and password
        # that is shown to every visitor; confirm it is an intentionally
        # public, read-only demo account, or move it to configuration.
        st.code("mysql+pymysql://admin:9522359448@mydatabase.cf8u2cy0a4h6.us-east-1.rds.amazonaws.com:3306/mydb")

        st.subheader("Sample Table")
        sample_data = pd.DataFrame({
            "id": [1, 2, 3, 4],
            "first_name": ["John", "Jane", "Tom", "Jerry"],
            "last_name": ["Doe", "Doe", "Smith", "Jones"],
            "email": ["johnD@abc.com", "JaneD@abc.com", "toms@abc.com", "Jerry@abc.com"],
            "hire_date": ["2020-01-01", "2020-05-01", "2020-03-01", "2020-02-01"],
            "salary": [50000, 60000, 70000, 80000],
        })
        st.dataframe(sample_data)

        st.subheader("Sample Questions")
        questions = [
            "What is the email of John?",
            "What is the lastname of Tom?",
            "Hiredate of the Jerry?",
        ]
        for q in questions:
            st.markdown(f"- {q}")

        st.divider()
        if st.button("Logout", type="primary"):
            logout()
# Chat interface page
def chat_page():
    """Render the chat page and forward new queries to the backend.

    The conversation is kept in ``st.session_state.chat_history`` as a list
    of ``{"role": ..., "content": ...}`` dicts. A new query is POSTed as
    ``{'query': ...}`` to ``/api/v1/query``; on HTTP 200 the response's
    ``result`` field (``"No result"`` if absent) is appended as the
    assistant's reply and the script reruns. Failures are shown as errors.
    The "End Chat" button returns to the database-connection page.
    """
    st.set_page_config(page_title="Talk2SQLπ¨πΌβπ»π’", layout="wide")
    st.title('Chat Interface')
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    query = st.chat_input("Enter your query")
    if query:
        st.session_state.chat_history.append({"role": "user", "content": query})
        # The history loop above already ran for this pass, so echo the new
        # query now; otherwise it stays invisible when the request fails.
        with st.chat_message("user"):
            st.write(query)
        try:
            # (connect, read) timeout so a stalled backend cannot hang the
            # page. NOTE(review): 120 s read budget is a guess; tune as needed.
            response = requests.post(
                'http://localhost:8000/api/v1/query',
                json={'query': query},
                timeout=(5, 120),
            )
            if response.status_code == 200:
                result = response.json().get("result", "No result")
                st.session_state.chat_history.append({"role": "assistant", "content": result})
                st.rerun()
            else:
                st.error(f'Query failed: {response.text}')
        except requests.RequestException as e:
            st.error(f'Error connecting to backend: {e}')

    if st.button("End Chat"):
        st.session_state.current_page = 'db_connection'
        st.rerun()
# Main app
def main():
    """Prepare storage and session state, then render the page for the current state.

    Routing: a logged-out user gets the login page; otherwise the page is
    chosen from ``st.session_state.current_page``. Requesting the chat page
    while no database is connected shows an error, resets the page to
    ``'db_connection'`` and reruns the script.
    """
    init_db()
    init_session_state()
    state = st.session_state

    if not state.logged_in:
        login_page()
        return

    page = state.current_page
    if page == 'db_connection':
        db_connection_page()
    elif page == 'chat':
        if not state.db_connected:
            st.error('Database not connected. Redirecting to Database Connection page')
            state.current_page = 'db_connection'
            st.rerun()
        chat_page()
def logout():
    """Log the current user out and return to the login page.

    Resets the auth/navigation flags and removes the chat history, so an
    account that signs in next from the same session does not see the
    previous user's conversation. Then reruns the script.
    """
    st.session_state.logged_in = False
    st.session_state.username = None
    st.session_state.current_page = 'login'
    st.session_state.db_connected = False
    # chat_history is created lazily by the chat page, so it may be absent.
    if 'chat_history' in st.session_state:
        del st.session_state['chat_history']
    st.rerun()
# Script entry point: run the app only when this file is executed directly.
if __name__ == '__main__':
    main()