sql-chat / app.py
devin-ai's picture
Update app.py
c722913 verified
raw
history blame contribute delete
No virus
3.21 kB
from dotenv import load_dotenv
load_dotenv()
import streamlit as st
import google.generativeai as genai
import sqlite3
import os
# Configure the Gemini client from the GOOGLE_API_KEY environment variable
# (which load_dotenv() may have populated from a .env file) and build the
# model object that gemini_sql_query() sends requests to.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
model = genai.GenerativeModel(model_name='gemini-pro')
# Few-shot instructions sent to Gemini ahead of the user's question;
# gemini_sql_query() passes prompt[0] as the first content part.
# Every example SQL statement must be valid SQLite, because the model imitates
# these examples and its output is executed directly against the database.
prompt = [
    """
You are an expert in converting English questions to SQLite SQL queries!
The examples below use a table named STUDENT with the columns NAME, CLASS,
SECTION and MARKS. The user's message ends with the name of the table to
query; always use that table name in your answer.

Example 1 - How many entries of records are present?
SQL: SELECT COUNT(*) FROM STUDENT;

Example 2 - Tell me all the students studying in Data Science class?
SQL: SELECT * FROM STUDENT WHERE CLASS = 'Data Science';

Example 3 - Retrieve the classes in which at least 2 students scored more than 40 marks.
SQL: SELECT CLASS FROM STUDENT WHERE MARKS > 40 GROUP BY CLASS HAVING COUNT(*) >= 2;

Example 4 - Find the names and marks of students in the Science class who have scored more than 60 marks.
SQL: SELECT NAME, MARKS FROM STUDENT WHERE CLASS = 'Science' AND MARKS > 60;

Example 5 - List the classes with the highest average marks.
SQL: SELECT CLASS FROM STUDENT GROUP BY CLASS HAVING AVG(MARKS) = (SELECT MAX(avg_marks) FROM (SELECT AVG(MARKS) AS avg_marks FROM STUDENT GROUP BY CLASS));

Return only the SQL statement itself: do not wrap it in ``` fences, do not
prefix it with the word sql, and do not add any explanation.
"""
]
# LLM response handling
def strip_sql_fences(text):
    """Return *text* without surrounding Markdown code fences or a leading
    ``sql``/``sqlite`` language-tag line.

    The prompt asks the model not to add these, but LLM replies are often
    still formatted as a fenced code block, which is not valid SQL.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    # A fenced block usually starts with a language tag (or nothing) on the
    # first line; drop that line only when it is such a tag.
    first_line, newline, remainder = cleaned.partition("\n")
    if newline and first_line.strip().lower() in {"", "sql", "sqlite"}:
        cleaned = remainder
    return cleaned.strip()


def gemini_sql_query(prompt, input):
    """Ask the Gemini model to translate a natural-language question into SQL.

    Args:
        prompt: Sequence whose first element holds the instructions sent
            before the question.
        input: The user's question. The name shadows the builtin but is kept
            for compatibility with existing callers.

    Returns:
        The generated SQL text, with any code fences stripped.
    """
    response = model.generate_content([prompt[0], input])
    # NOTE(review): response.text presumably raises when the reply was blocked
    # or has no candidates — confirm and surface a friendly error if needed.
    return strip_sql_fences(response.text)
# Run a query against a SQLite database file and return its rows
def read_sql_query(sql, db):
    """Execute *sql* against the SQLite database at *db* and return all rows.

    Args:
        sql: A single SQL statement; it is executed verbatim (it is
            LLM-generated in this app, so treat it as untrusted).
        db: Path to the SQLite database file.

    Returns:
        A list of row tuples from ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If the statement fails (e.g. invalid SQL or an
            unknown table/column).
    """
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        rows = cursor.fetchall()
        conn.commit()
    finally:
        # Always release the connection, even when the generated SQL is
        # invalid and execute() raises.
        conn.close()
    # Echo the rows to the server console for debugging.
    for row in rows:
        print(row)
    return rows
# Page setup; set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="DataChat: Explore Your Database")
st.header("DataChat: Chat With SQL Database")
question = st.text_input("Enter your input/question")
table_name = st.text_input("Enter the correct table name")
# Build the LLM request only when both fields are filled in. Otherwise leave it
# empty so the later `submit and uploaded_file and input` check actually blocks
# submission (an unconditional f-string is never empty).
# The name `input` shadows the builtin but is kept because later code reads it.
if question.strip() and table_name.strip():
    input = f"{question.strip()} in {table_name.strip()} table"
else:
    input = ""
# Save uploaded file
def save_uploaded_file(uploaded_file, *, directory=None, file_name="uploaded.db"):
    """Write the uploaded file's bytes to disk and return the path written.

    Args:
        uploaded_file: Object exposing ``getbuffer()``, such as the value
            returned by ``st.file_uploader`` or an ``io.BytesIO``.
        directory: Target directory; defaults to the current working directory.
        file_name: Name of the file to create; defaults to ``"uploaded.db"``.

    Returns:
        The full path of the written file. Any existing file at that path is
        overwritten.
    """
    # NOTE(review): with the defaults, every session writes to the same path,
    # so concurrent users can overwrite each other's database — consider a
    # per-session directory.
    target_dir = os.getcwd() if directory is None else directory
    file_path = os.path.join(target_dir, file_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
# File uploader component
st.sidebar.header("Database Upload")
# SQLite files are commonly named *.db, *.sqlite or *.sqlite3; accept all three.
uploaded_file = st.sidebar.file_uploader(
    "Upload SQLite Database", type=["db", "sqlite", "sqlite3"]
)
# db_path stays None until a database has been uploaded and saved, so the
# name is always defined for later code.
db_path = None
if uploaded_file is not None:
    # Save the uploaded file
    db_path = save_uploaded_file(uploaded_file)
    st.sidebar.success("Database uploaded successfully.")
submit = st.button("submit")
if submit and uploaded_file and input:
    query = gemini_sql_query(prompt, input)
    # Log the generated SQL before running it so failures can be diagnosed.
    print(query)
    try:
        response = read_sql_query(query, db_path)
    except sqlite3.Error as err:
        # LLM-generated SQL may be invalid or reference unknown tables/columns;
        # report the problem instead of crashing the page with a traceback.
        st.error(f"Could not run the generated SQL query: {err}")
        st.code(query, language="sql")
    else:
        col1, col2 = st.columns(2)
        with col1:
            st.header("Response:")
            if not response:
                st.info("The query returned no rows.")
            for row in response:
                values = [str(value) for value in row]
                st.write(*values)
        with col2:
            st.header("Generated SQL Query:")
            with st.container(height=300):
                # st.code defaults to Python highlighting; mark this as SQL.
                st.code(query, language="sql")
elif submit:
    st.warning("Upload a database and enter both a question and a table name.")