File size: 6,866 Bytes
a71dd13
 
 
 
 
 
 
 
 
 
 
aa1ee73
a71dd13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa1ee73
a71dd13
 
 
 
 
 
aa1ee73
 
 
a71dd13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa1ee73
 
 
 
 
 
a71dd13
aa1ee73
a71dd13
 
 
 
 
 
aa1ee73
a71dd13
aa1ee73
 
 
 
 
 
 
 
 
 
 
a71dd13
 
aa1ee73
a71dd13
aa1ee73
 
 
a71dd13
aa1ee73
a71dd13
 
aa1ee73
 
 
a71dd13
aa1ee73
a71dd13
aa1ee73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a71dd13
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import google.generativeai as genai
from dotenv import load_dotenv
import streamlit as st
import os
import sqlite3
import pandas as pd
import io
from langchain.prompts import PromptTemplate
from utils.extarct_db import extract_name_and_colums , read_excel_query , read_sql_query
from typing import List
import numpy as np
import plotly.express as px

import matplotlib.pyplot as plt


# Load environment variables through python-dotenv before the API key is read.
load_dotenv()

# API key passed to the Gemini client; None when GOOGLE_API_KEY is not set.
key = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=key)

# Module-level Gemini model handle ("gemini-pro").
model = genai.GenerativeModel("gemini-pro")



# Ask Gemini to translate a natural-language question into a SQL query


def _strip_code_fences(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence.

    The prompt asks the model not to wrap its answer in ``` fences or to
    prefix it with the word "sql", but nothing enforces that. This helper
    removes an opening/closing ``` pair and a leading "sql" language tag so
    the caller receives bare SQL. Text that is not fenced is only stripped
    of surrounding whitespace.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned

    cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    cleaned = cleaned.strip()

    # A fence is often opened as ```sql; drop that tag when it stands alone
    # (followed by whitespace) at the start of the fenced body.
    if cleaned[:3].lower() == "sql" and cleaned[3:4].isspace():
        cleaned = cleaned[3:].lstrip()
    return cleaned


def get_gemini_response(question: str, table_name: str, column_names: List[str]) -> str:
    """Ask Gemini to translate an English question into a SQL query.

    Args:
        question: The user's natural-language question.
        table_name: Table the query should target; substituted into the
            prompt.
        column_names: Column names of that table; joined with ", " and
            substituted into the prompt.

    Returns:
        The model's reply text with surrounding whitespace and any Markdown
        code fence removed.

    Errors raised by the Gemini client are not caught here and propagate to
    the caller.
    """
    # Convert column_names list to a comma-separated string
    columns_str = ', '.join(column_names)

    # Defining the prompt
    prompt = PromptTemplate(
        input_variables=["table_name", "columns"],
        template=""" 

        You are an expert in converting English questions to SQL query!

        The SQL database has the name {table_name} and has the following columns - 

        {columns}\n\nFor example,\nExample 1 - How many entries of records are present?,

        the SQL command will be something like this SELECT COUNT(*) FROM {table_name} ;

        \nExample 2 - Tell me all the students studying in Data Science class?,

        the SQL command will be something like this SELECT * FROM {table_name}

        also the sql code should not have ``` in beginning or end and sql word in output 

        

        """
    )

    # Format the prompt with the table_name and columns_str
    formatted_prompt = prompt.format(table_name=table_name, columns=columns_str)

    # The instructions and the user's question are sent as two content parts.
    model = genai.GenerativeModel('gemini-pro')
    response = model.generate_content([formatted_prompt, question])

    # The model does not always honour the "no ``` / no sql word" instruction,
    # so clean the reply before it is handed to the SQL executor.
    return _strip_code_fences(response.text)
    
# ---------------------------------------------------------------------------
# Streamlit page: upload a SQLite (.db) or Excel (.xlsx) file, ask a question
# in English, run the SQL Gemini produces for it, and show the result as a
# table plus a bar chart.
# ---------------------------------------------------------------------------

# Upload MIME types the page accepts. Generic binary uploads are treated as
# SQLite database files; the second type is an .xlsx workbook.
_SQLITE_MIME = "application/octet-stream"
_XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

# Where an uploaded workbook is materialised so it can be queried with SQL.
_EXCEL_DB_NAME = "data.db"
_EXCEL_TABLE_NAME = "excel_data"


def _parse_markdown_table(text: str):
    """Parse a pipe-delimited Markdown table into ``(columns, records)``.

    ``columns`` is the list of header cells and ``records`` a list of dicts
    mapping header -> cell text, one per data row. Lines without a ``|`` and
    separator rows such as ``|---|:--:|`` are ignored wherever they appear.
    Empty cells are kept so the remaining cells stay under the right header.
    Returns ``([], [])`` when no table row is found.
    """
    rows = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if "|" not in line:
            continue
        # Separator rows consist solely of pipes, dashes, colons and spaces.
        if set(line) <= set("|-: "):
            continue
        # Remove only the outer pipes; inner empty cells must survive.
        if line.startswith("|"):
            line = line[1:]
        if line.endswith("|"):
            line = line[:-1]
        rows.append([cell.strip() for cell in line.split("|")])

    if not rows:
        return [], []

    columns, *data_rows = rows
    records = [dict(zip(columns, row)) for row in data_rows]
    return columns, records


def _render_chart(df, columns: List[str]) -> None:
    """Draw a bar chart of the second column against the first, if possible."""
    has_two_columns = (
        len(columns) >= 2
        and columns[0] in df.columns
        and columns[1] in df.columns
    )
    if not has_two_columns or df.empty:
        st.info("Not enough data available to do analysis")
        return

    x_col, y_col = columns[0], columns[1]
    chart_df = df.copy()
    try:
        # Cells come out of the Markdown table as text; convert the y column
        # so numeric results are plotted on a numeric axis.
        chart_df[y_col] = pd.to_numeric(chart_df[y_col])
    except (ValueError, TypeError):
        # Non-numeric values: leave the column as text and plot it as-is.
        pass

    try:
        fig = px.bar(chart_df, x=x_col, y=y_col, title="Visualize")
        st.plotly_chart(fig)
    except Exception:
        # The chart is a best-effort extra; never let it break the page.
        st.info("Not enough data available to do analysis")


def _answer_question(question: str, table_name: str, column_names: List[str], db_path: str) -> None:
    """Generate SQL for *question*, run it on *db_path* and render the result."""
    if not question or not question.strip():
        st.warning("Please enter a question before running the query.")
        return

    try:
        sql = get_gemini_response(question, table_name, column_names)
        # NOTE(review): the model-generated SQL is executed as-is against the
        # uploaded data; nothing here restricts it to read-only statements.
        response = read_sql_query(sql, db_path)
        formatted_response = model.generate_content(f"Format this {response} in the table format")
        response_text = formatted_response.candidates[0].content.parts[0].text
    except Exception as exc:
        # UI boundary: report the failure instead of crashing the page.
        st.error(f"Could not answer the question: {exc}")
        return

    columns, records = _parse_markdown_table(response_text)
    st.subheader("The Response is ")
    if not columns:
        # The model did not answer with a table; show its text unchanged.
        st.write(response_text)
        return

    df = pd.DataFrame(records) if records else pd.DataFrame(columns=columns)
    st.table(df)
    _render_chart(df, columns)


def _stage_sqlite_upload(uploaded_file):
    """Save an uploaded SQLite file and return ``(table, columns, db_path)``."""
    # Keep only the final path component so the upload's name cannot point
    # outside the working directory.
    safe_name = os.path.basename(uploaded_file.name)
    file_path = os.path.join(os.getcwd(), safe_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    db_info = extract_name_and_colums(safe_name)
    table_name = db_info['table_name'][0]
    # Presumably 'colum_names' maps table name -> list of columns (it was
    # previously indexed with the literal 'STUDENT'); look up the table that
    # was actually found instead. TODO confirm against utils.extarct_db.
    column_names = db_info['colum_names'][table_name]
    return table_name, column_names, file_path


def _stage_excel_upload(uploaded_file):
    """Load an uploaded workbook into SQLite; return ``(table, columns, db_path)``."""
    # Read straight from the upload object; no copy of the workbook is
    # written to disk.
    df = pd.read_excel(uploaded_file)
    conn = sqlite3.connect(_EXCEL_DB_NAME)
    try:
        df.to_sql(_EXCEL_TABLE_NAME, conn, index=False, if_exists="replace")
    finally:
        conn.close()

    # The table was just written from this frame, so its columns are known.
    column_names = [str(col) for col in df.columns]
    return _EXCEL_TABLE_NAME, column_names, _EXCEL_DB_NAME


st.header("ConverseDB")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=None)

if uploaded_file is not None:
    file_type = uploaded_file.type

    if file_type in (_SQLITE_MIME, _XLSX_MIME):
        question = st.text_input("Input Prompt: ", key="input")
        submit = st.button("Query")

        # Streamlit re-runs this script on every interaction, so the upload
        # is only staged when the user actually asks for a query.
        if submit:
            try:
                if file_type == _SQLITE_MIME:
                    target = _stage_sqlite_upload(uploaded_file)
                else:
                    target = _stage_excel_upload(uploaded_file)
            except Exception as exc:
                st.error(f"Could not read the uploaded file: {exc}")
            else:
                table_name, column_names, db_path = target
                _answer_question(question, table_name, column_names, db_path)
    else:
        st.error("File type is not allowed. Please upload a .db or .xlsx file.")