arjunanand13 committed on
Commit
7de61f5
1 Parent(s): 17122a5

Upload 19 files

Browse files
Antilia.pdf ADDED
Binary file (13.9 kB). View file
 
Biltmore_Estate.pdf ADDED
Binary file (25.9 kB). View file
 
Buckingham_palace.pdf ADDED
Binary file (16.6 kB). View file
 
Hearst_castle.pdf ADDED
Binary file (13.2 kB). View file
 
Istana_Nurul_Iman.pdf ADDED
Binary file (11.5 kB). View file
 
Palace_of_Versailes.pdf ADDED
Binary file (18.1 kB). View file
 
Taj_Mahal_palace.pdf ADDED
Binary file (16.2 kB). View file
 
Villa_Leopolda.pdf ADDED
Binary file (14.1 kB). View file
 
Villa_Les_Cedres.pdf ADDED
Binary file (11.5 kB). View file
 
White_House.pdf ADDED
Binary file (35.4 kB). View file
 
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageFilter
3
+ import os
4
+ from txt2sql_code3 import SQLPromptModel
5
+ from qa_bot_chatgpt import QAInfer
6
+ from gradio import Button
7
+ import time
8
+ import google.generativeai as genai
9
+
10
+ image = Image.open(os.path.join(os.path.abspath(''), "house_excel_sheet.png"))
11
+
12
+ def image_display(steps=0):
13
+ return image
14
+
15
+ query = None
16
+ rows = None
17
+ columns = None
18
+ user_choices = None
19
+ data = [
20
+ [1, "Buckingham Palace", 27, 34, 0.12, "London", "United Kingdom", 601, 920, 105000, "Buckingham_palace.pdf"],
21
+ [2, "White House", 355, 67, 0.08, "Washington D.C.", "United States", 601, 1527, 80000, "White_House.pdf"],
22
+ [3, "Taj Mahal Palace", 455, 76, 0.15, "Mumbai", "India", 795, 748, 67000, "Taj_Mahal_palace.pdf"],
23
+ [4, "Versailles Palace", 455, 45, 0.1, "Versailles", "France", 731, 1800, 145000, "Palace_of_Versailes.pdf"],
24
+ [5, "Villa Leopolda", 223, 21, 0.05, "Villefranche-sur-Mer", "France", 680, 6886, 65000, "Villa_Leopolda.pdf"],
25
+ [6, "Antilia", 455, 70, 0.46, "Mumbai", "India", 612, 2520, 179000, "Antilia.pdf"],
26
+ [7, "The Biltmore Estate", 544, 93, 0.2, "Asheville", "United States", 639, 2040, 50000, "Biltmore_Estate.pdf"],
27
+ [8, "Hearst Castle", 303, 57, 0.08, "San Simeon", "United States", 731, 1050, 71600, "Hearst_castle.pdf"],
28
+ [9, "Villa Les Cèdres", 489, 88, 0.09, "Saint-Jean-Cap-Ferrat", "France", 730, 1092, 100000, "Villa_Les_Cedres.pdf"],
29
+ [10, "Istana Nurul Iman", 350, 34, 7.46, "Bandar Seri Begawan", "Brunei", 670, 5403, 300000, "Istana_Nurul_Iman.pdf"]
30
+ ]
31
+ choices = [item[1] for item in data]
32
+
33
+
34
+
35
def execute_sql_query(input_prompt):
    """Translate a natural-language question into SQL, run it, and return the rows.

    The generated query, the result rows, the result column names and the
    list of property labels offered in the second-stage dropdown are stored
    in the module-level ``query``, ``rows``, ``columns`` and ``user_choices``.

    :param input_prompt: question typed by the user; a blank or missing value
        falls back to the default question about properties in India.
    :return: the list of result rows, or ``None`` when the table schema is
        missing or every attempt failed.
    """
    global query, rows, columns, user_choices

    default_prompt = "Give complete details of properties in India"
    max_attempts = 3

    # Blank / missing input means "use the default question".
    if input_prompt and input_prompt.strip():
        effective_prompt = input_prompt
    else:
        effective_prompt = default_prompt

    sql_model = SQLPromptModel("multi_table_demo/checkpoint-2600", r"sql_pdf.db")
    try:
        for attempt in range(1, max_attempts + 1):
            try:
                table_schema = sql_model.fetch_table_schema("sql_pdf")
                if not table_schema:
                    print("Table schema not found.")
                    return None

                query = sql_model.text2sql_gemini(table_schema, default_prompt, effective_prompt)
                rows, columns = sql_model.execute_query(query)
                print(rows)

                # The dropdown label of a row is its first string-valued cell.
                labels = []
                for item in rows:
                    label = next((val for val in item if isinstance(val, str)), None)
                    if label is not None:
                        labels.append(label)
                user_choices = labels
                return rows
            except Exception as e:
                print(f"An error occurred: {e}")
                # Only wait when another attempt is actually going to happen.
                if attempt < max_attempts:
                    print("Retrying...")
                    time.sleep(1)
        return None  # every attempt failed
    finally:
        # A new model (and SQLite connection) is created per call; release it.
        sql_model.conn.close()
75
+
76
+
77
+ # def qa_infer_interface(row,query_on_pdf):
78
+ # qa_infer=QAInfer()
79
+ # qa_infer.qa_infer(query,rows,columns)
80
+ user_choices = None
81
def update_choices(nothing):
    """Rebuild the property dropdown from the most recent query results.

    :param nothing: value of the component that triggered the event; unused.
    :return: a ``gr.Dropdown`` listing ``user_choices`` (empty when no query
        has produced results yet).
    """
    print("callback called")
    print("choices",choices)
    print("user_choices",user_choices)
    return gr.Dropdown(
        choices=user_choices if user_choices else [],
        label="Property Choice",
        info="List of all properties",
        interactive=True,
    )
90
+
91
def update_examples(nothing):
    """Return example ``[property, question]`` pairs for the QA stage.

    :param nothing: value of the component that triggered the event; unused.
    :return: three example rows.  When ``user_choices`` is populated the
        first choice is used as the property; otherwise the property is left
        blank.
    """
    if user_choices:
        first_choice = user_choices[0]
        return [
            [first_choice, "Structure of the property"],
            [first_choice, "Property History"],
            [first_choice, "How many floors does the property have"],
        ]
    # Previously this branch built a tuple and fell through, returning None.
    return [
        ["", "Structure of the property "],
        ["", " Property History "],
        ["", " How many floors does the property have"],
    ]
97
+
98
+
99
def qa_infer_interface(property_choice, query_question):
    """Answer a question about one property using its PDF.

    :param property_choice: property name picked in the dropdown; when empty
        the first entry of ``user_choices`` is used.
    :param query_question: question about the property; defaults to "area"
        when empty.
    :return: the generated answer, an explanatory message when the property
        is not in ``data``, or ``None`` when all three attempts fail.
    """
    if not property_choice and user_choices:
        property_choice = user_choices[0]

    # Look the property up by name (column 1 of ``data``).  The dropdown can
    # hold values that are not property names, so a miss must not crash.
    property_row = next((row for row in data if row[1] == property_choice), None)
    if property_row is None:
        message = f"Property not found: {property_choice!r}"
        print(message)
        return message

    if not query_question:
        query_question = "area"

    qa_infer = QAInfer()
    for _ in range(3):
        try:
            print(property_row)
            return qa_infer.qa_infer_interface_gemini(property_row, query_question)
        except Exception as e:
            print(f"Error occurred while inferring QA: {e}")
    print("Failed to infer QA after 3 retries.")
    return None
118
+
119
+
120
+
121
+ user_dropdown=gr.Dropdown(choices=[], label="Property Choice",info="List of all properties")
122
+ properties_text=gr.components.Textbox(lines=2,label="User Database Query",placeholder="Click on an query from 'examples' below or write your own query based on the database above. Default : 'Properties in India'")
123
+ interface_1_output=gr.Json(label="json")
124
+ stage2_examples=[["","Structure of the property "],[ ""," Property History "] ,["", " How many floors does the property have"]]
125
+ stage2_text=gr.components.Textbox(lines=2,label="Question on property",placeholder="Enter a question to know more about the properties , you can choose from one of the options below or write a own question Default: 'Area of the property'",)
126
+ stage2_output="text"
127
+
128
+ with gr.Blocks(title="House Excel Query") as demo:
129
+
130
+ gr.Markdown("# House Excel Query")
131
+
132
+ generated_image = image_display()
133
+ gr.Image(generated_image)
134
+
135
+ gr.Markdown("""### The database provided contains information about different properties, including their fundamental details. Additional specifics about each property are stored in associated PDF files, which are referenced in the "PDF" column. You have the capability to query this database using various criteria. When a query is initiated, the system generates SQL queries and extracts relevant rows from the database in the backend.
136
+ \n ### Once the properties are retrieved based on the query, you can utilize the user interface (UI) below to perform question answering (QA). Simply select a property from the list of returned properties and compose a question pertaining to that property. You will receive an answer based on the available information.""")
137
+
138
+ interface_1 = gr.Interface(
139
+ execute_sql_query,
140
+ inputs=properties_text,
141
+ # "textbox",
142
+ outputs=interface_1_output,
143
+ # live=True,
144
+ # cache_examples=["Give me all details of properties from India"],
145
+ examples=["Properties in France "," Properties greater than a acre","Properties with more than 400 bedrooms"],
146
+ )
147
+ # print(interface_1.input_components[0])
148
+
149
+
150
+ interface_2 = gr.Interface(
151
+ qa_infer_interface,
152
+ inputs=[user_dropdown,stage2_text],
153
+ # inputs=[gr.Dropdown.change(fn=update_choices),gr.components.Textbox(lines=2,label="Question on property",placeholder="Enter a question to know more about the properties")],
154
+ outputs=stage2_output,
155
+ # examples=stage2_examples,
156
+ # live=True,
157
+ # gr.Button("Next"),
158
+ # Button.click(next,value="Next"),
159
+
160
+ )
161
+
162
+ gr.Examples(["How many floors does the property have "," Total square feet of the property " ," Total area of the property"],inputs=stage2_text,outputs=stage2_output,fn=qa_infer_interface)
163
+
164
+
165
+ properties_text.change(update_choices,inputs=[properties_text],outputs=[user_dropdown])
166
+ interface_1_output.change(update_choices,inputs=[interface_1_output],outputs=[user_dropdown])
167
+
168
+
169
+ # user_dropdown.change(update_examples, inputs=[user_dropdown], outputs=[stage2_examples])
170
+
171
+ # properties_text.change(update_choices,inputs=[stage2_examples],outputs=[interface_2.examples])
172
+ # interface_1_output.change(update_choices,inputs=[stage2_examples],outputs=[interface_2.examples])
173
+
174
+ # user_dropdown.change(update_choices, inputs=[user_dropdown], outputs=[user_dropdown, interface_2])
175
+
176
+ # user_dropdown.change(fn=update_choices,inputs=[user_dropdown],outputs=[user_dropdown])
177
+
178
+ # with gr.Row():
179
+ # save_btn = gr.Button("Next")
180
+
181
+ # Button.click(next,value="Next",),
182
+
183
+ if __name__ == "__main__":
184
+ demo.launch(share=True)
185
+
186
+
187
+ ## download pdf buttons
188
+ ## upload pdf
189
+ ## dynamic selection
app2.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
def greet(name, api=""):
    """Return a greeting for *name*.

    :param name: text entered by the user.
    :param api: optional suffix appended after the literal ``"api"``.  The
        original body read an undefined global ``api`` and raised
        ``NameError`` on every call; it is now an explicit, defaulted
        parameter.  NOTE(review): confirm what value was meant to go here.
    :return: the greeting string.
    """
    return "Hello! " + name + "api" + api
5
+
6
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ iface.launch()
house_excel_sheet.png ADDED
qa_bot_chatgpt.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from sqlite3 import Error
3
+ from PyPDF2 import PdfReader
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
5
+ import os
6
+ import torch
7
+ from huggingface_hub import login
8
+ import ast
9
+ from openai import OpenAI
10
+ import google.generativeai as genai
11
class QAInfer:
    """Answer natural-language questions about a property from its PDF.

    The PDF text is extracted with PyPDF2 and sent, together with the
    question, to OpenAI chat completions or to Gemini.  API credentials are
    read from the ``OPENAI_API_KEY`` and ``GOOGLE_API_KEY`` environment
    variables; secrets must not be committed to source control.
    """

    # Opening paragraph shared by both Gradio prompt variants.
    _PROMPT_INTRO = (
        "You have been provided with a question and a corresponding context. "
        "Your task is to search the context to find the answer to the question. "
        "If the answer is found, return the response. If the answer cannot be "
        "found in the context, please respond with "
        "\"Answer not found in the context\"."
    )
    _CHATGPT_RESPONSE_HINT = (
        "Try mostly to answer from given pdf , if related answer is not found "
        "return 'Information not present in the pdf' and below it provide "
        "something related to the question .Note: return only answer dont "
        "include terms like 'Response','###','Answer'"
    )
    _GEMINI_RESPONSE_HINT = (
        "If related answer is not found return 'Information not present in "
        "the pdf' and below it provide something related to the question"
    )
    _CHATGPT_SYSTEM_MESSAGE = (
        "You are a expert PDF parser , go through the pdf and answer the "
        "question properly , if related answer is not found return "
        "'Information not present in the pdf' and below it provide something "
        "related to the question"
    )

    def __init__(self):
        """Configure the OpenAI and Gemini clients from the environment."""
        torch.cuda.empty_cache()

        openai_key = os.environ.get("OPENAI_API_KEY")
        # The OpenAI client is optional: the Gradio app only uses Gemini.
        self.chatgpt_client = OpenAI(api_key=openai_key) if openai_key else None

        self.genai = genai
        self.genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
        self.genai_model = genai.GenerativeModel('gemini-pro')

        # Optional local text-generation model for ``qa_infer``; when left
        # unset, ``qa_infer`` answers through Gemini instead.
        self.qa_model = None
        self.qa_tokenizer = None

    def extract_text_from_pdf(self, pdf_path):
        """Extract text from a PDF file.

        :param pdf_path: path of the PDF to read.
        :return: the text of all pages concatenated in page order; pages for
            which no text is returned contribute an empty string.
        """
        reader = PdfReader(pdf_path)
        return ''.join(page.extract_text() or '' for page in reader.pages)

    @staticmethod
    def _build_interface_prompt(query_question, pdf_text, response_hint):
        """Assemble the question/context/response prompt used by the Gradio methods."""
        return (
            QAInfer._PROMPT_INTRO + "\n"
            "\n"
            "=== Question ===\n"
            f"{query_question}\n"
            "\n"
            "=== Context ===\n"
            f"{pdf_text}\n"
            "\n"
            "=== Response ===\n"
            f"{response_hint}"
        )

    def _generate_answer(self, prompt):
        """Generate an answer with the local model if configured, else with Gemini."""
        qa_model = getattr(self, "qa_model", None)
        qa_tokenizer = getattr(self, "qa_tokenizer", None)
        if qa_model is not None and qa_tokenizer is not None:
            pipe = pipeline(
                "text-generation",
                model=qa_model,
                tokenizer=qa_tokenizer,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )
            sequences = pipe(
                prompt,
                do_sample=True,
                max_new_tokens=200,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1,
            )
            return sequences[0]['generated_text']
        return self.genai_model.generate_content(prompt).text

    def qa_infer(self, query, rows, col):
        """Interactive (console) QA loop over the PDFs of the given rows.

        :param query: the SQL query that produced ``rows`` (only printed).
        :param rows: result rows; the ``additional_info`` cell of each row is
            used as the path of that property's PDF.
        :param col: column names matching the cells of each row.
        :return: ``None``.  Nothing is asked when ``col`` has no
            ``additional_info`` column or the user declines.
        """
        print(query)

        print(tuple(col))
        if "additional_info" not in col:
            return

        file_index = list(col).index("additional_info")
        initiate_qa = input("\nDo you wish to ask questions about the properties [y/n]?: ").lower()
        if initiate_qa not in ['y', 'yes']:
            return

        for row in rows:
            pdf_text = self.extract_text_from_pdf(row[file_index])
            print("Extracted text from PDF", os.path.basename(row[file_index]))

            while True:
                user_question = input("\nWhat do you want to know about this property? (Press Enter to exit): ").strip()
                if not user_question:
                    break
                prompt = f"""Below is a question and context, search the context to find the answer for the question and return the response ###question:{user_question} ###context:{pdf_text} ###response:"""
                print("Answer:", self._generate_answer(prompt))

            # Asked once the user has finished with the current property.
            continue_to_next = input("Do you want to continue with the next property? [y/n]: ").lower()
            if continue_to_next != 'y':
                return

    def qa_infer_interface(self, row, query_question):
        """Answer a question through OpenAI chat completions (Gradio only).

        :param row: property row whose last element is the PDF path.
        :param query_question: question to answer from the PDF text.
        :return: the text of the model's reply.
        :raises RuntimeError: if ``OPENAI_API_KEY`` was not set when this
            object was created.
        """
        if self.chatgpt_client is None:
            raise RuntimeError("OPENAI_API_KEY is not set; the OpenAI client is unavailable.")

        file_path = row[-1]  # the last element of the row holds the PDF path
        pdf_text = self.extract_text_from_pdf(file_path)
        prompt = self._build_interface_prompt(
            query_question, pdf_text, self._CHATGPT_RESPONSE_HINT
        )

        print(prompt)
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": self._CHATGPT_SYSTEM_MESSAGE},
                {"role": "user", "content": prompt}
            ]
        )
        return completion.choices[0].message.content

    def qa_infer_interface_gemini(self, row, query_question):
        """Answer a question through Gemini (Gradio only).

        :param row: property row whose last element is the PDF path.
        :param query_question: question to answer from the PDF text.
        :return: the text of the model's reply.
        """
        file_path = row[-1]  # the last element of the row holds the PDF path
        pdf_text = self.extract_text_from_pdf(file_path)
        prompt = self._build_interface_prompt(
            query_question, pdf_text, self._GEMINI_RESPONSE_HINT
        )

        print(prompt)
        completion = self.genai_model.generate_content(prompt)
        return completion.text
132
+
133
+
134
+ if __name__ == '__main__':
135
+ qa_infer = QAInfer()
136
+ query = 'SELECT * FROM sql_pdf WHERE country = "India" '
137
+ rows = [
138
+ (3, 'Taj Mahal Palace', 455, 76, 0.15, 'Mumbai', 'India', 795, 748, 67000, 'pdf_files/pdf/Taj_Mahal_palace.pdf'),
139
+ (6, 'Antilia', 455, 70, 0.46, 'Mumbai', 'India', 612, 2520, 179000, 'pdf_files/pdf/Antilia.pdf')
140
+ ]
141
+ col = [
142
+ "property_id", "name", "bed", "bath", "acre_lot", "city", "country",
143
+ "zip_code", "house_size", "price", "additional_info"
144
+ ]
145
+ qa_infer.qa_infer(query, rows, col)
146
+
147
+
148
+
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bitsandbytes
2
+ git+https://github.com/huggingface/transformers.git
3
+ git+https://github.com/huggingface/peft.git
4
+ git+https://github.com/huggingface/accelerate.git
5
+ datasets
6
+ evaluate
7
+ trl==0.7.1
8
+ jupyter
9
+ scipy
10
+ gradio
11
+ python-dotenv
12
+ openpyxl
13
+ PyPDF2
14
+ llama-parse
15
+ google-generativeai
sql.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from sqlite3 import Error
3
+ import csv
4
+ import pandas as pd
5
+ import os
6
+
7
+
8
def create_connection(db_file):
    """Open a SQLite connection to the database at ``db_file``.

    :param db_file: path of the database file (``":memory:"`` also works).
    :return: the ``sqlite3.Connection``, or ``None`` when connecting fails
        (the error is printed).
    """
    try:
        return sqlite3.connect(db_file)
    except Error as e:
        print(e)
        return None
21
+
22
+
23
def create_table(conn, create_table_sql, table_name):
    """Drop ``table_name`` if it exists, then run ``create_table_sql``.

    :param conn: Connection object
    :param create_table_sql: a CREATE TABLE statement
    :param table_name: name of the table to drop before re-creating it
    :return: None; SQLite errors are printed rather than raised
    """
    statements = (
        f"""DROP TABLE IF EXISTS {table_name}""",
        create_table_sql,
    )
    try:
        cursor = conn.cursor()
        for statement in statements:
            cursor.execute(statement)
    except Error as err:
        print(err)
35
+
36
+
37
def insert_values(conn, task, sql):
    """Execute a parameterised statement and commit it.

    :param conn: Connection object
    :param task: sequence of values bound to the placeholders in ``sql``
    :param sql: SQL statement with ``?`` placeholders
    :return: ``lastrowid`` of the cursor that ran the statement
    """
    cursor = conn.cursor()
    cursor.execute(sql, task)
    conn.commit()
    row_id = cursor.lastrowid
    return row_id
43
+
44
+
45
def populate(csv_file, db_file, table_insert):
    """Insert every data row of a CSV file into a database table.

    The first CSV line is treated as a header and skipped; empty lines are
    ignored.  All rows are inserted in one transaction, and the connection
    is closed afterwards.

    :param csv_file: path of the CSV file to read (UTF-8).
    :param db_file: path of the SQLite database file.
    :param table_insert: ``INSERT`` statement with one ``?`` per CSV column.
    :return: None
    """
    conn = create_connection(db_file)
    if conn is None:
        print("Error! cannot create the database connection.")
        return

    try:
        # newline='' is what the csv module requires for correct parsing of
        # quoted fields that contain line breaks.
        with open(csv_file, mode='r', newline='', encoding='utf-8') as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header line
            records = [tuple(line) for line in reader if line]

        # ``with conn`` commits on success and rolls back on error; it does
        # not close the connection, hence the ``finally`` below.
        with conn:
            conn.executemany(table_insert, records)
    finally:
        conn.close()
61
+
62
def _dominant_sqlite_type(values, type_map):
    """Return the SQLite type mapped to the most frequent Python type in *values*."""
    from collections import Counter

    counts = Counter(type(value) for value in values)
    if not counts:
        return "TEXT"
    most_common_type = counts.most_common(1)[0][0]
    # Types without an explicit mapping are stored as TEXT.
    return type_map.get(most_common_type, "TEXT")


def main():
    """Convert ``sql_pdf.xlsx`` into a CSV file and a SQLite database.

    The table is named after the workbook, each column's SQLite type is
    derived from the most frequent Python type among its cells, the table is
    (re)created and then filled from the CSV export.
    """
    name = "sql_pdf.xlsx"
    table_name = name.split(".")[0]

    excel_file = pd.read_excel(name)
    csv_file = f"{table_name}.csv"
    excel_file.to_csv(csv_file, index=None, header=True)

    type_map = {
        str: "TEXT",
        int: "INTEGER",
        float: "REAL",
    }

    # Previously the running maximum was never updated, so the *last* type
    # seen in a column won instead of the most frequent one.
    cell_rows = excel_file.values
    column_type = {}
    for index, column_name in enumerate(excel_file.columns):
        column_type[column_name] = _dominant_sqlite_type(
            (row[index] for row in cell_rows), type_map
        )

    print(column_type)

    column_defs = ", ".join(f"{col} {sql_type}" for col, sql_type in column_type.items())
    column_names = ", ".join(column_type)
    placeholders = ", ".join("?" for _ in column_type)

    table_construct = f"CREATE TABLE IF NOT EXISTS {table_name}( {column_defs});"
    table_insert = f"INSERT INTO {table_name}({column_names})\nVALUES ({placeholders})"

    print(table_construct)
    print("\n\n", table_insert)

    database = f"{table_name}.db"

    conn = create_connection(database)
    if conn is None:
        print("Error! cannot create the database connection.")
        return

    # create the table, then release this connection; populate opens its own
    try:
        create_table(conn, table_construct, table_name)
    finally:
        conn.close()

    populate(csv_file, database, table_insert)
126
+
127
+ if __name__ == '__main__':
128
+ main()
sql_pdf.db ADDED
Binary file (8.19 kB). View file
 
sql_pdf.xlsx ADDED
Binary file (11.1 kB). View file
 
txt2sql_code3.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from sqlite3 import Error
3
+ from peft import AutoPeftModelForCausalLM
4
+ from transformers import AutoTokenizer, BitsAndBytesConfig
5
+ from transformers import AutoModelForCausalLM
6
+ from openai import OpenAI
7
+ import google.generativeai as genai
8
+
9
class SQLPromptModel:
    """Turn natural-language questions into SQL and run them on a SQLite database.

    Three generators are available: a local causal LM (``text2sql``), OpenAI
    chat completions (``text2sql_chatgpt``) and Gemini (``text2sql_gemini``).
    API credentials are read from the ``OPENAI_API_KEY`` and
    ``GOOGLE_API_KEY`` environment variables; secrets must not be committed
    to source control.
    """

    _PROMPT_HEADER = (
        "Below are SQL table schemas paired with instructions that describe a task.\n"
        "Using valid SQLite, write a response that appropriately completes the request for the provided tables.\n"
    )
    _INPUT_REQUEST = "Press Enter for default question or Enter user prompt without newline characters: "

    def __init__(self, model_dir, database):
        """Set up the API clients and open the SQLite connection.

        :param model_dir: directory of the fine-tuned local model (stored only;
            the local model is not loaded here).
        :param database: path of the SQLite database file.
        """
        import os

        self.model_dir = model_dir
        self.database = database

        # The local model is optional; ``text2sql`` needs both to be assigned.
        self.model = None
        self.tokenizer = None

        openai_key = os.environ.get("OPENAI_API_KEY")
        # The OpenAI client is optional: the Gradio app only uses Gemini.
        self.chatgpt_client = OpenAI(api_key=openai_key) if openai_key else None

        self.genai = genai
        self.genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
        self.genai_model = genai.GenerativeModel('gemini-pro')

        self.conn = sqlite3.connect(self.database)

    def fetch_table_schema(self, table_name):
        """Fetch the schema of a table from the database.

        :param table_name: name of the table to inspect.
        :return: the ``table_info`` rows (cid, name, type, notnull,
            dflt_value, pk), or ``None`` if the table has no columns or does
            not exist.
        """
        cursor = self.conn.cursor()
        try:
            # Table-valued PRAGMA form so the name can be bound as a parameter.
            cursor.execute("SELECT * FROM pragma_table_info(?)", (table_name,))
            schema = cursor.fetchall()
        finally:
            cursor.close()
        if schema:
            return schema
        print(f"Table {table_name} does not exist or has no schema.")
        return None

    def _build_prompt(self, schema, user_prompt, inp_prompt, response_hint):
        """Assemble the text-to-SQL prompt.

        ``inp_prompt`` (or, when it is ``None``, a line read from the
        console) replaces ``user_prompt`` as the instruction; an empty
        override keeps ``user_prompt``.
        """
        table_columns = ', '.join(f"{col[1]} {col[2]}" for col in schema)
        if inp_prompt is None:
            inp_prompt = input(self._INPUT_REQUEST).strip()
        # An override is followed by a single space, matching earlier prompts.
        instruction = f"{inp_prompt} " if inp_prompt else user_prompt
        return (
            self._PROMPT_HEADER
            + f"### Instruction: {instruction} ###\n"
            + f"Input: CREATE TABLE sql_pdf({table_columns});\n"
            + f"### Response: ({response_hint})"
        )

    @staticmethod
    def _extract_select(text):
        """Pull the SELECT statement out of a model reply.

        Returns the span from the first ``SELECT`` (any case) through the
        next ``;``.  Without a ``;`` the rest of the reply is returned with
        markdown code fences and surrounding whitespace removed.  A reply
        containing no ``SELECT`` is returned unchanged.
        """
        import re

        match = re.search(r"\bSELECT\b", text, re.IGNORECASE)
        if match is None:
            return text
        start_index = match.start()
        end_index = text.find(";", start_index)
        if end_index != -1:
            return text[start_index:end_index + 1]
        return text[start_index:].replace("```", "").strip()

    def text2sql(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with the local model.

        :param schema: ``table_info`` rows of the target table.
        :param user_prompt: default instruction.
        :param inp_prompt: instruction override (used by the Gradio app);
            when ``None`` the user is asked on the console.
        :return: the text generated after the prompt.
        :raises RuntimeError: if no local model/tokenizer has been assigned.
        """
        model = getattr(self, "model", None)
        tokenizer = getattr(self, "tokenizer", None)
        if model is None or tokenizer is None:
            raise RuntimeError(
                "No local model is loaded; assign self.model and self.tokenizer "
                "or use text2sql_gemini / text2sql_chatgpt."
            )

        prompt = self._build_prompt(
            schema, user_prompt, inp_prompt, "Return only query , nothing extra"
        )

        # Move the input to the device the model lives on.
        input_ids = tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).input_ids.to(next(model.parameters()).device)
        outputs = model.generate(input_ids=input_ids, max_new_tokens=200)
        response = tokenizer.batch_decode(
            outputs.detach().cpu().numpy(), skip_special_tokens=True
        )[0]
        return response[len(prompt):]

    def text2sql_chatgpt(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with OpenAI chat completions.

        :param schema: ``table_info`` rows of the target table.
        :param user_prompt: default instruction.
        :param inp_prompt: instruction override; when ``None`` the user is
            asked on the console.
        :return: the model's reply text.
        :raises RuntimeError: if ``OPENAI_API_KEY`` was not set when this
            object was created.
        """
        if self.chatgpt_client is None:
            raise RuntimeError("OPENAI_API_KEY is not set; the OpenAI client is unavailable.")

        prompt = self._build_prompt(
            schema, user_prompt, inp_prompt,
            "Return only generated query based on user_prompt , nothing extra",
        )
        print(prompt)
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a expert SQL developer , generate a sql query and return it"},
                {"role": "user", "content": prompt}
            ]
        )
        return completion.choices[0].message.content

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with Gemini.

        :param schema: ``table_info`` rows of the target table.
        :param user_prompt: default instruction.
        :param inp_prompt: instruction override; when ``None`` the user is
            asked on the console.
        :return: the SELECT statement found in the reply, or the raw reply
            when it contains no SELECT.
        """
        prompt = self._build_prompt(
            schema, user_prompt, inp_prompt,
            "Return only generated query based on user_prompt , nothing extra",
        )
        print(prompt)
        completion = self.genai_model.generate_content(prompt)
        return self._extract_select(completion.text)

    def execute_query(self, query):
        """Execute the query on the database and return rows and columns.

        :param query: SQL text to run.
        :return: ``(rows, col)`` where ``rows`` is the list of result tuples
            and ``col`` the list of result column names (both empty for
            statements that produce no result set).
        """
        print(query)
        cur = self.conn.cursor()
        try:
            cur.execute(query)
            # ``description`` is None for statements without a result set.
            col = [header[0] for header in (cur.description or ())]
            dash = "-" * sum(len(col_name) + 4 for col_name in col)
            print(tuple(col))
            print(dash)
            rows = []
            for member in cur:
                rows.append(member)
                print(member)
        finally:
            cur.close()
        self.conn.commit()
        return rows, col
141
+
142
+ if __name__ == "__main__":
143
+ model_dir = "multi_table_demo/checkpoint-2600"
144
+ database = r"sql_pdf.db"
145
+ sql_model = SQLPromptModel(model_dir, database)
146
+ user_prompt = "Give complete details of properties in India"
147
+ while True:
148
+ table_schema = sql_model.fetch_table_schema("sql_pdf")
149
+ if table_schema:
150
+ # query = sql_model.text2sql(table_schema, user_prompt)
151
+ # query = sql_model.text2sql_chatgpt(table_schema, user_prompt)
152
+ query = sql_model.text2sql_gemini(table_schema, user_prompt)
153
+ print(query)
154
+ sql_model.execute_query(query)
155
+
156
+ sql_model.conn.close()
157
+
158
+