aryachakraborty's picture
Upload 4 files
1bb8aca verified
import streamlit as st
from My_SQL_Connection import database_details, tables_in_this_DB, printing_tables, create_table_command,retrieve_result
from streamlit_option_menu import option_menu
from model_functions import LOAD_GEMMA,DeepSeekCoder,LOAD_GEMMA_GGUF
import torch
import mysql.connector
user_name = 'arya & Shritama'
st.set_page_config(page_title="My SQL Explorer", page_icon="🔍", layout="centered", initial_sidebar_state="expanded")
if 'localhost' not in st.session_state:
st.session_state.localhost = ''
st.session_state.user = ''
st.session_state.password = ''
st.session_state.table_commands = """ """
with st.sidebar:
selected = option_menu("Querio Lingua", ["Log In", 'main functionalities','Chat with AI'],
icons=['person-circle', 'info-circle-fill', 'chat-fill'], menu_icon="cast", default_index=0,
styles={
"container": {"padding": "5!important","background-color":'black'},
"icon": {"color": "white", "font-size": "23px"},
"nav-link": {"color":"white","font-size": "20px", "text-align": "left", "margin":"0px", "--hover-color": "gray"},
"nav-link-selected": {"background-color": "#1B2135"},})
if selected == 'Log In':
st.subheader('Please Log in into your MySql server by providing the following details ~ ')
st.session_state.localhost = st.text_input("what is your host, (localhost if in local) or give the url", 'localhost',help='host')
st.session_state.user = st.text_input("what is your user name (usually root)", 'root')
st.session_state.password = st.text_input('Password', type='password')
elif selected == 'main functionalities':
st.subheader('welcome to our MY SQL Database Explorer ~ ')
if st.button('All your databases ~ '):
try:
db, l = database_details(st.session_state.localhost, st.session_state.user, st.session_state.password)
st.table(db)
except mysql.connector.Error as e:
error_code = e.errno
st.warning(f"An error occurred (Error Code: {error_code}). Please check your login details.")
st.subheader('Now we will see details of any database~ ')
st.session_state.db_name = st.text_input('Which Database you want')
if st.button('All tables present in that particular database'):
if not st.session_state.db_name:
st.warning('Input database name first')
else:
try:
tables, l = tables_in_this_DB(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
st.write(f'There is only {l} tables present in this database')
st.markdown(f"**:rainbow[{tables[0][0]}]**")
except mysql.connector.Error as e:
st.warning("An error occured. Please select the correct database from the above list or check that you are loged in into your server.")
st.subheader('check out tables~ ')
if st.button('Print the tables~'):
try:
tables_data = printing_tables(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
for table_name, table_data in tables_data.items():
st.write(f"Table: {table_name}")
st.table(table_data)
except mysql.connector.Error as e:
st.warning("An error occured. Please check that you have selected a database or have loged in into your server.")
st.subheader('Retrieve the CREATE TABLE Statements')
statement_options = st.radio("Choose the Context option for chat",["Generate the Context for chat AI based on your tables",
"Give custom chat context"])
if statement_options == 'Generate the Context for chat AI based on your tables':
if st.button('Generate context'):
try:
statements = create_table_command(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
for table_name, table_statements in statements.items():
st.write(f'{table_name}')
st.session_state.table_commands = table_statements
st.code(table_statements)
except mysql.connector.Error as e:
st.warning('An error occured. Please check that you have selected a database or have loged in into your server.')
elif statement_options == 'Give custom chat context':
context = st.text_area("Paste your context here (Usually the tables schema)")
st.session_state.table_commands = context
elif selected == 'Chat with AI':
if "generated" not in st.session_state:
st.session_state["generated"] = []
if "past" not in st.session_state:
st.session_state["past"] = []
if "input" not in st.session_state:
st.session_state["input"] = ""
if "stored_session" not in st.session_state:
st.session_state["stored_session"] = []
def get_text():
"""
Get the user input text.
Returns:
(str): The text entered by the user
"""
input_text = st.text_input("You: ", st.session_state["input"], key="input",
placeholder="Your AI assistant here! Ask me anything ...",
label_visibility='hidden')
return input_text
def new_chat():
"""
Clears session state and starts a new chat.
"""
save = []
for i in range(len(st.session_state['generated'])-1, -1, -1):
save.append("User:" + st.session_state["past"][i])
save.append("Bot:" + st.session_state["generated"][i])
st.session_state["stored_session"].append(save)
st.session_state["generated"] = []
st.session_state["past"] = []
st.session_state["input"] = ""
#with st.sidebar.expander("Available Fine Tuned Models", expanded=False):
MODEL = st.sidebar.selectbox(label='Available Fine Tuned Models', options=['GEMMA-2B','Gemma-GGUF', 'DeepSeekCoder 1.3B'])
st.sidebar.warning('Load only one model at a time as it loads the model into cache so it may cause cache overload',icon="⚠️")
st.title("Querio Lingua 🤖")
st.markdown("Your own SQL code helper⭐")
st.markdown(" Powered by GEMMA & DeepSeek🚀")
st.sidebar.button("New Chat", on_click = new_chat, type='primary')
user_input = get_text()
if user_input:
if MODEL == 'GEMMA-2B':
gemma_tokenizer,gemma_model = LOAD_GEMMA()
device = torch.device("cpu")
alpeca_prompt = f"""Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables.
### Instruction: {user_input}. ### Input: {st.session_state.table_commands}
### Response:
"""
with st.status('Generating Result',expanded=False) as status:
inputs = gemma_tokenizer([alpeca_prompt], return_tensors="pt").to(device)
outputs = gemma_model.generate(**inputs, max_new_tokens=30)
output = gemma_tokenizer.decode(outputs[0], skip_special_tokens=True)
response_portion = output.split("### Response:")[-1].strip()
st.session_state.past.append(user_input)
st.session_state.generated.append(response_portion)
status.update(label="Result Generated!", state="complete", expanded=False)
elif MODEL == 'Gemma-GGUF':
with st.status('Generating Result',expanded=False) as status:
response = LOAD_GEMMA_GGUF(user_input,st.session_state.table_commands)
response_portion = response.split("### Response:")[-1].strip()
st.session_state.past.append(user_input)
st.session_state.generated.append(response_portion)
elif MODEL == 'DeepSeekCoder 1.3B':
with st.status('Generating Result',expanded=False) as status:
try:
response_portion = DeepSeekCoder(user_input,st.session_state.table_commands)
final_output = response_portion + f"\n {retrieve_result(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name,response_portion)}"
st.session_state.past.append(user_input)
st.session_state.generated.append(final_output)
print(final_output)
except mysql.connector.Error as e:
st.session_state.past.append(user_input)
st.session_state.generated.append(response_portion + '{Query not executable}')
status.update(label="Result Generated!", state="complete", expanded=False)
download_str = []
# Display the conversation history using an expander, and allow the user to download it
with st.expander("Conversation", expanded=True):
for i in range(len(st.session_state['generated'])-1, -1, -1):
st.info(st.session_state["past"][i],icon="🧐")
st.success(st.session_state["generated"][i], icon="🤖")
download_str.append(st.session_state["past"][i])
download_str.append(st.session_state["generated"][i])
# Can throw error - requires fix
download_str = '\n'.join(download_str)
if download_str:
st.download_button('Download',download_str)
# Display stored conversation sessions in the sidebar
for i, sublist in enumerate(st.session_state.stored_session):
with st.sidebar.expander(label= f"Conversation-Session:{i}"):
st.write(sublist)
# Allow the user to clear all stored conversation sessions
if st.session_state.stored_session:
if st.sidebar.button("Clear-all"):
del st.session_state.stored_session