# ConversDB / app.py
# space-runner's picture
# Upload 11 files
# aa1ee73 verified
import google.generativeai as genai
from dotenv import load_dotenv
import streamlit as st
import os
import sqlite3
import pandas as pd
import io
from langchain.prompts import PromptTemplate
from utils.extarct_db import extract_name_and_colums , read_excel_query , read_sql_query
from typing import List
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
# Pull settings such as GOOGLE_API_KEY from a local .env file into the process environment.
load_dotenv()
key = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=key)
# Module-level Gemini client; the Streamlit script further down uses it to format query results.
model = genai.GenerativeModel(model_name='gemini-pro')
# Ask Gemini to translate an English question into a SQL query
def _strip_sql_fences(text: str) -> str:
    """Return *text* trimmed and without a surrounding Markdown code fence.

    Handles both a bare fence and one tagged with the 'sql' language hint, since
    the model sometimes ignores the instruction not to emit them and SQLite
    cannot execute fenced text.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        # An optional language tag may directly follow the opening fence.
        if cleaned[:3].lower() == "sql":
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return cleaned.strip()


def get_gemini_response(question: str, table_name: str, column_names: List[str]) -> str:
    """Ask Gemini to translate an English question into a SQL query.

    Args:
        question: The user's natural-language question.
        table_name: Name of the table the query should target.
        column_names: Column names of that table; they are listed in the prompt
            so the model can reference them.

    Returns:
        The generated SQL text, trimmed, with any surrounding Markdown code
        fence removed.
    """
    # The prompt template expects the columns as one comma-separated string.
    columns_str = ', '.join(column_names)
    # Example 2 includes the WHERE clause its question asks for; an unfiltered
    # example teaches the model to drop filters.
    prompt = PromptTemplate(
        input_variables=["table_name", "columns"],
        template="""
You are an expert in converting English questions to SQL query!
The SQL table is named {table_name} and has the following columns -
{columns}\n\nFor example,\nExample 1 - How many entries of records are present?,
the SQL command will be something like this SELECT COUNT(*) FROM {table_name} ;
\nExample 2 - Tell me all the students studying in Data Science class?,
the SQL command will be something like this SELECT * FROM {table_name} WHERE CLASS = 'Data Science';
Write SQLite-compatible SQL and wrap column names that contain spaces in double quotes.
Return only the SQL query: no ``` fences at the beginning or end and no word sql in the output.
""",
    )
    formatted_prompt = prompt.format(table_name=table_name, columns=columns_str)
    gemini = genai.GenerativeModel('gemini-pro')
    # The question is sent as a separate content part after the instructions.
    response = gemini.generate_content([formatted_prompt, question])
    return _strip_sql_fences(response.text)
# Upload MIME types the app accepts. Generic binary uploads are treated as SQLite
# database files; .xlsx workbooks are first converted into a SQLite table.
SQLITE_MIME_TYPE = "application/octet-stream"
XLSX_MIME_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
ALLOWED_MIME_TYPES = (SQLITE_MIME_TYPE, XLSX_MIME_TYPE)

# SQLite file and table that an uploaded Excel workbook is loaded into.
EXCEL_DB_NAME = "data.db"
EXCEL_TABLE_NAME = "excel_data"

NOT_ENOUGH_DATA_MESSAGE = "Not enough data available to do analysis"


def _describe_first_table(db_name: str) -> tuple:
    """Return ``(table_name, column_names)`` for the first table reported for *db_name*."""
    db_info = extract_name_and_colums(db_name)
    table_name = db_info['table_name'][0]
    # NOTE(review): assumes 'colum_names' maps each table name to its column
    # list — confirm in utils.extarct_db.
    return table_name, db_info['colum_names'][table_name]


def _load_sqlite_upload(uploaded_file) -> tuple:
    """Save an uploaded SQLite file to the working directory.

    Returns ``(db_path, table_name, column_names)`` describing its first table.
    """
    # Keep only the final path component so a client-supplied name such as
    # "../../x.db" cannot write outside the working directory.
    db_name = os.path.basename(uploaded_file.name)
    db_path = os.path.join(os.getcwd(), db_name)
    with open(db_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    table_name, column_names = _describe_first_table(db_name)
    return db_path, table_name, column_names


def _load_excel_upload(uploaded_file) -> tuple:
    """Copy the first sheet of an uploaded .xlsx workbook into a SQLite table.

    The sheet replaces table EXCEL_TABLE_NAME inside EXCEL_DB_NAME.
    Returns ``(db_path, table_name, column_names)``.
    """
    # Read from the upload's bytes so the workbook is never left on disk.
    sheet = pd.read_excel(io.BytesIO(uploaded_file.getvalue()))
    conn = sqlite3.connect(EXCEL_DB_NAME)
    try:
        sheet.to_sql(EXCEL_TABLE_NAME, conn, index=False, if_exists="replace")
    finally:
        conn.close()
    table_name, column_names = _describe_first_table(EXCEL_DB_NAME)
    return EXCEL_DB_NAME, table_name, column_names


def _parse_markdown_table(text: str) -> pd.DataFrame:
    """Parse a Markdown pipe table into a DataFrame of strings.

    Lines without a ``|`` (prose, code fences) and the header/body delimiter
    row are ignored; the first remaining row is the header. Cells are split
    positionally so an empty cell does not shift later values into the wrong
    column. Returns an empty DataFrame when no table is found.
    """
    rows = []
    for raw_line in text.strip().splitlines():
        line = raw_line.strip()
        if "|" not in line:
            continue
        # Drop one optional outer pipe on each side before splitting.
        if line.startswith("|"):
            line = line[1:]
        if line.endswith("|"):
            line = line[:-1]
        cells = [cell.strip() for cell in line.split("|")]
        # Delimiter row such as |---|:---:|
        if all(cell and "-" in cell and set(cell) <= set("-:") for cell in cells):
            continue
        rows.append(cells)
    if not rows:
        return pd.DataFrame()
    header, *body = rows
    width = len(header)
    # Pad short rows and trim long ones so every row lines up with the header.
    body = [(row + [None] * width)[:width] for row in body]
    return pd.DataFrame(body, columns=header)


def _render_bar_chart(df: pd.DataFrame) -> None:
    """Plot the second column of *df* against the first as a bar chart, when possible."""
    if df.shape[1] < 2 or df.empty:
        st.info(NOT_ENOUGH_DATA_MESSAGE)
        return
    # Parsed cells are strings; convert the value column so bar heights are numeric.
    values = pd.to_numeric(df.iloc[:, 1], errors="coerce")
    if values.isna().all():
        st.info(NOT_ENOUGH_DATA_MESSAGE)
        return
    # Plain lists (not named Series) so duplicate header names cannot collide.
    fig = px.bar(
        x=df.iloc[:, 0].tolist(),
        y=values.tolist(),
        labels={"x": str(df.columns[0]), "y": str(df.columns[1])},
        title="Visualize",
    )
    st.plotly_chart(fig)


def _answer_question(question: str, table_name: str, column_names, db_path: str) -> None:
    """Generate SQL for *question*, run it against *db_path* and render the result."""
    sql = get_gemini_response(question, table_name, column_names)
    # The SQL comes from an LLM steered by free-form user text; refuse anything
    # that is not a SELECT. This is a first-line guard, not a sandbox.
    if not sql.lstrip().upper().startswith("SELECT"):
        st.error("Only read-only SELECT queries can be run. Generated SQL:")
        st.code(sql, language="sql")
        return
    try:
        result = read_sql_query(sql, db_path)
    except sqlite3.Error as err:
        # NOTE(review): assumes read_sql_query lets sqlite3 errors propagate — confirm.
        st.error(f"The generated SQL could not be executed: {err}")
        st.code(sql, language="sql")
        return
    # Ask Gemini to lay the raw result out as a Markdown table, then parse it back.
    formatted_response = model.generate_content(f"Format this {result} in the table format")
    response_text = formatted_response.candidates[0].content.parts[0].text
    df = _parse_markdown_table(response_text)
    st.subheader("The Response is ")
    if df.columns.empty:
        # No table in the model's reply; show it verbatim instead of an empty table.
        st.markdown(response_text)
    else:
        st.table(df)
    _render_bar_chart(df)


st.header("ConverseDB")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=None)
if uploaded_file is not None:
    file_type = uploaded_file.type
    if file_type in ALLOWED_MIME_TYPES:
        question = st.text_input("Input Prompt: ", key="input")
        submit = st.button("Query")
        if file_type == SQLITE_MIME_TYPE:
            db_path, table_name, column_names = _load_sqlite_upload(uploaded_file)
        else:
            db_path, table_name, column_names = _load_excel_upload(uploaded_file)
        if submit:
            if question.strip():
                _answer_question(question, table_name, column_names, db_path)
            else:
                st.warning("Please enter a question before querying.")
    else:
        st.error("File type is not allowed. Please upload a .db or .xlsx file.")