Spaces:
Running
Running
import os
import sqlite3
import tempfile
from contextlib import closing, suppress

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from langchain.agents import create_sql_agent
from langchain.chat_models import ChatOpenAI
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_groq import ChatGroq
# Load variables from a local .env file via python-dotenv. Nothing below reads
# os.environ directly; presumably this is for provider SDKs that can pick up
# credentials from the environment — confirm before removing.
load_dotenv()
def is_valid_sqlite(file_path):
    """Return True if *file_path* is an existing file SQLite can open as a database.

    A trivial query against ``sqlite_master`` is executed, because SQLite
    only validates the file header once a statement actually runs; a
    non-database file surfaces there as ``sqlite3.DatabaseError``.

    Args:
        file_path: Path to the candidate database file.

    Returns:
        bool: True if the file exists and the probe query succeeds,
        False otherwise.
    """
    # sqlite3.connect() would silently create an empty database at a missing
    # path (which would then "validate"), so reject non-files up front.
    if not os.path.isfile(file_path):
        return False
    try:
        # The sqlite3 connection's own context manager only commits/rolls
        # back; it never closes the connection. closing() releases the handle.
        with closing(sqlite3.connect(file_path)) as conn:
            conn.execute("SELECT name FROM sqlite_master LIMIT 1;")
    except sqlite3.DatabaseError:
        return False
    return True
def text_to_sql(query: str, db_path: str, llm_provider: str, api_key: str, model_name: str) -> str:
    """Answer a natural-language *query* about the SQLite DB at *db_path* with an LLM SQL agent.

    Args:
        query: The user's question in natural language.
        db_path: Filesystem path of the SQLite database to query.
        llm_provider: One of ``'OPENAI'``, ``'OPEN_ROUTER'`` or ``'GROQ'``.
        api_key: API key for the chosen provider.
        model_name: Model identifier passed to the provider's chat model.

    Returns:
        str: The agent's final answer; ``"Unsupported LLM provider selected."``
        for an unknown provider; or ``"Error: <message>"`` if anything fails.
    """
    # Provider -> chat-model factory. Factories defer construction so only the
    # selected provider's class is ever instantiated.
    llm_factories = {
        'OPENAI': lambda: ChatOpenAI(api_key=api_key, model=model_name),
        'OPEN_ROUTER': lambda: ChatOpenAI(
            api_key=api_key,
            base_url='https://openrouter.ai/api/v1',
            model=model_name,
        ),
        'GROQ': lambda: ChatGroq(api_key=api_key, model=model_name),
    }
    make_llm = llm_factories.get(llm_provider)
    # Reject unknown providers before opening and reflecting the database.
    if make_llm is None:
        return "Unsupported LLM provider selected."
    try:
        db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
        llm = make_llm()
        toolkit = SQLDatabaseToolkit(llm=llm, db=db)
        agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
        # Chain.run() is deprecated; invoke() takes an input dict and returns a
        # dict whose "output" key holds the agent's final answer.
        return agent.invoke({"input": query})["output"]
    except Exception as e:
        # UI boundary: report any failure as text on the page instead of crashing.
        return f"Error: {e}"
def show_tables_as_df(db_path):
    """Render the first 10 rows of every table in the SQLite DB at *db_path* with Streamlit.

    Each table gets an ``st.subheader`` followed by an ``st.dataframe``. If the
    database has no tables, a short notice is written instead.

    Args:
        db_path: Path to a SQLite database file.
    """
    # closing() guarantees the connection is released even if a query raises
    # (the sqlite3 connection's own context manager does not close it).
    with closing(sqlite3.connect(db_path)) as conn:
        table_names = [
            row[0]
            for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
        ]
        if not table_names:
            st.write("No tables found in database.")
            return
        for table in table_names:
            st.subheader(f"Table: {table}")
            # Table names come from an uploaded (untrusted) file and may contain
            # spaces, quotes or SQL keywords (e.g. "order"): quote them as SQLite
            # identifiers, doubling embedded double quotes, so the query stays valid.
            quoted = '"' + table.replace('"', '""') + '"'
            df = pd.read_sql_query(f"SELECT * FROM {quoted} LIMIT 10", conn)
            st.dataframe(df)
# Streamlit UI
st.title('🗄️ Chat with YOUR SQLite Database')
st.write("Upload your `.db` file and interact using natural language queries powered by LLMs.")
uploaded_file = st.file_uploader("Upload your `.db` file", type=["db"], accept_multiple_files=False)
llm_provider = st.radio("Choose LLM Provider", options=['OPEN_ROUTER', 'GROQ', 'OPENAI'])
model_name = st.text_input("Enter the Model Name", value='nousresearch/deephermes-3-mistral-24b-preview:free')
api_key = st.text_input("Enter Your API Key", type="password")
query = st.text_area("Enter Your Query")
if uploaded_file:
    # Streamlit re-executes this script on every widget interaction, so each run
    # writes its own temporary copy of the upload and removes it in `finally`
    # (previously every rerun leaked a new file). getvalue() returns the whole
    # upload regardless of the buffer's current read position, and omitting
    # `dir` uses the platform temp directory instead of a hard-coded /tmp.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmpfile:
        tmpfile.write(uploaded_file.getvalue())
        tmp_db_path = tmpfile.name
    try:
        if not is_valid_sqlite(tmp_db_path):
            st.error("Uploaded file is not a valid SQLite database.")
        else:
            st.success(f"Valid database `{uploaded_file.name}` uploaded!")
            st.info("Displaying first 10 rows from each table:")
            show_tables_as_df(tmp_db_path)
            if st.button("RUN Query"):
                if not api_key or not model_name:
                    st.error("Please provide API key and model name.")
                elif not query.strip():
                    st.error("Please enter a query.")
                else:
                    st.info(f"Running query on `{uploaded_file.name}`...")
                    result = text_to_sql(query, tmp_db_path, llm_provider, api_key, model_name)
                    st.success("Query Result:")
                    st.write(result)
    finally:
        # Best-effort cleanup: if the file is still held open on a platform that
        # forbids deleting open files, skip removal rather than crash the page.
        with suppress(OSError):
            os.remove(tmp_db_path)
else:
    st.info("Please upload a SQLite `.db` file to begin.")