Samka1u committed (verified)
Commit 84780b0
1 Parent(s): ec3e46a

Create app.py

Files changed (1)
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
import os
import pickle
import shutil
import pandas as pd
import gradio as gr
from config import PATHS
from secret_keys import *
from smolagents import CodeAgent, InferenceClientModel, tool
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
    inspect,
    text,
    exc,
)

engine = create_engine("sqlite:///agentDB.db")
metadata_obj = MetaData()

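# `config.PATHS` is assumed to expose PKL_FILE_PATH, TEMP_PATH, and TABLE_NAME, and
# `secret_keys` to define NEBIUS_API_KEY; all four are referenced below. A minimal,
# illustrative config.py sketch (values are hypothetical, not part of this commit):
#
#   class PATHS:
#       PKL_FILE_PATH = "temp/table_data.pkl"   # hypothetical path
#       TEMP_PATH = "temp/"                     # hypothetical path
#       TABLE_NAME = "agent_table"              # the name the sql_engine docstring refers to
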
def load_rows():
    """
    Loads the pickled dictionary (orient='list'), keyed by column name with each value being
    that column's list of values, and converts it into a list of row dictionaries.
    Args:
        None
    Returns:
        col_names (list): The list of column names.
        rows (list): List of row dictionaries, one per table row.
        num_cols (int): Number of columns.
    """
    # load dict from pickle
    with open(PATHS.PKL_FILE_PATH, "rb") as f:
        sql_dict = pickle.load(f)

    # collect column names
    col_names = list(sql_dict.keys())
    num_cols = len(col_names)

    # Ensure the dictionary is not empty
    if not col_names:
        raise ValueError("The dictionary is empty.")

    # collect table rows from dict
    num_rows = len(sql_dict[col_names[0]])
    rows = []
    # Iterate through the dict, collecting each column's value for the current row
    for i in range(num_rows):
        row = {}
        for col in col_names:
            value = sql_dict[col][i]
            row[col] = value
        rows.append(row)
    return col_names, rows, num_cols

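# Example (hypothetical data): a pickled dict {"name": ["a", "b"], "age": [30, 40]}
# yields col_names == ["name", "age"], num_cols == 2, and
# rows == [{"name": "a", "age": 30}, {"name": "b", "age": 40}].
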
def insert_rows(rows, table, engine=engine):
    """
    Insert rows into a table.
    Args:
        rows (list): List of row dictionaries keyed by column name.
        table (sqlalchemy.Table): Table to insert into.
        engine (sqlalchemy.engine.Engine): SQLAlchemy engine to be used.
    Returns:
        None
    """
    for row in rows:
        stmt = insert(table).values(**row)
        with engine.begin() as connection:
            connection.execute(stmt)


def create_dynamic_table(table_name, columns):
    """
    Creates a SQL table dynamically.
    Args:
        table_name (String): Name of the table.
        columns (dict): Mapping of column names to SQLAlchemy column types.
    Returns:
        table: The table object.
    """
    table = Table(
        table_name,
        metadata_obj,
        Column('id', Integer, primary_key=True),
        *[Column(name, type_) for name, type_ in columns.items()],
        extend_existing=True
    )
    return table

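# Example (hypothetical schema): create_dynamic_table("agent_table", {"name": String, "age": Integer})
# defines a table with an auto-generated 'id' primary key plus the two given columns;
# metadata_obj.create_all(engine) then emits the CREATE TABLE statement.
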
def update_table(column_type):
    """
    Updates the table with columns from the gradio textbox. Calls load_rows() to read the pkl
    file and get the rows, column names, and column count.
    Returns a relevant error message if the number of data types does not match the number of
    columns, if the user did not input a recognized data type, or if inserting the rows fails.
    Args:
        column_type (String): The user's comma-separated column data types.
    Returns:
        (String): Success message when there are no errors, otherwise the error that occurred.
    """
    # load rows for the table
    col_names, rows, num_cols = load_rows()
    # split str into list of data types
    dataType_list = column_type.split(",")
    try:
        if len(dataType_list) != len(col_names):
            raise ValueError(
                f"Number of data types ({len(dataType_list)}) does not match number of columns ({len(col_names)})"
            )
        for i in range(len(dataType_list)):
            match dataType_list[i].strip():
                case "String":
                    dataType_list[i] = String
                case "Integer":
                    dataType_list[i] = Integer
                case "Float":
                    dataType_list[i] = Float
            if dataType_list[i] != String and dataType_list[i] != Float and dataType_list[i] != Integer:
                raise TypeError()
    except TypeError:
        return "A data type you entered was invalid. Use String, Integer, or Float."
    except ValueError as e:
        return f"{e}."

    # Dynamically create the columns dictionary
    columns = {
        col_name: dataType_list[i]  # Map column name to data type by index
        for i, col_name in enumerate(col_names)
    }
    len_cols = len(columns)
    dynamic_table = create_dynamic_table(PATHS.TABLE_NAME, columns)
    metadata_obj.create_all(engine)

    try:
        insert_rows(rows, dynamic_table)
    except exc.CompileError as e:
        return f"{e}."
    except exc.OperationalError as e:
        return f"{e}. agentDB has already had its schema defined."
    return "Row insertion successful"

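# Example: for a CSV with three columns, entering "String, Integer, Float" in the textbox
# maps each column to the corresponding SQLAlchemy type by position before the table is
# created and the rows are inserted.
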
def table_description():
    """
    Generates a description of the table to feed to the agent prompt.
    Args:
        None
    Returns:
        table_description (String): The table's column names and their data types.
    """
    inspector = inspect(engine)
    try:
        columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(PATHS.TABLE_NAME)]
        table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
    except exc.NoSuchTableError as e:
        return f"NoSuchTableError: {e}. The referenced table does not exist."
    return table_description

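# Example output (column names hypothetical; type names reflect the SQLite dialect):
#   "Columns:\n - name: VARCHAR\n - age: INTEGER"
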
def table_check() -> str:
    """
    Verify the table exists. Returns a string which says whether the table exists or not.
    Args:
        None
    Returns:
        (String): A message containing the table status.
    """
    try:
        inspector = inspect(engine)
        if inspector.has_table(PATHS.TABLE_NAME):
            return f"Table '{PATHS.TABLE_NAME}' exists."
        else:
            raise exc.NoSuchTableError(PATHS.TABLE_NAME)
    except exc.NoSuchTableError as e:
        return f"NoSuchTableError: {e}. The referenced table does not exist."


@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named agent_table.
    Args:
        query: The query to be performed on the table. This should always be correct SQL.
    """
    output = ""

    with engine.begin() as con:
        try:
            # run the query inside the transaction opened by engine.begin()
            rows = con.execution_options(autocommit=True).execute(text(query))
            if not rows.returns_rows:
                return "No rows found, include the `RETURNING` keyword to ensure the result object always returns rows."
            for row in rows:
                output += str(row) + "\n"
        except exc.SQLAlchemyError as e:
            return f"{e}. Include the `RETURNING` keyword to ensure the result object always returns rows."
    return output

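# Example calls the agent might make (column names hypothetical; RETURNING requires a
# reasonably recent SQLite):
#   sql_engine("SELECT name, age FROM agent_table LIMIT 5;")
#   sql_engine("DELETE FROM agent_table WHERE id = 1 RETURNING id;")
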
def agent_setup():
    """
    Initialize the inference client, as well as the sql agent.
    Args:
        None
    Returns:
        sql_agent (Agent): The agent that will be used for inference.
    """
    sql_model = InferenceClientModel(
        api_key=NEBIUS_API_KEY,
        model_id="Qwen/Qwen3-235B-A22B",  # Qwen/Qwen3-4B
        provider="nebius",
    )
    # define SQL Agent
    sql_agent = CodeAgent(
        tools=[sql_engine],
        model=sql_model,
        max_steps=5,
    )
    return sql_agent


def run_prompt(prompt, history):
    """
    Run the user's prompt through the sql agent, appending the table description for context.
    Args:
        prompt (String): The user's query to be fed to the agent.
        history (Any): Chat history supplied by gr.ChatInterface (unused).
    Returns:
        (String): The agent's response, or an error message if the table does not exist.
    """
    table_descrip = table_description()
    table_status = table_check()
    if "NoSuchTableError" in table_status:
        return table_status + " Check the table has the expected name and it is consistent."
    return agent.run(prompt + f". Always wrap the result in relevant context and enforce the results object returning rows. Table description is as follows:{table_descrip}")

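# Example prompt: "How many rows are in the table?" The table description and the instruction
# to wrap the result in context are appended before the prompt is handed to the agent.
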
def vote(data: gr.LikeData):
    """
    Provide feedback on the agent's response.
    Args:
        data (LikeData): Carries information about the .like() event.
    Returns:
        None
    """
    if data.liked:
        print("You upvoted this response: " + data.value["value"])
    else:
        print("You downvoted this response: " + data.value["value"])


def process_file(fileobj):
    """
    Save the uploaded file to the temporary folder.
    Args:
        fileobj (Any): The uploaded file.
    Returns:
        None (calls csv_2_dict)
    """
    csv_path = os.path.join(PATHS.TEMP_PATH, os.path.basename(fileobj.name))
    # copy file to path
    shutil.copyfile(fileobj.name, csv_path)
    return csv_2_dict(csv_path)


def csv_2_dict(path):
    """
    Reads the csv as a dataframe, converts it to a dictionary, and writes it to a pkl file
    in the temporary folder.
    Args:
        path (Any): The temporary file path.
    Returns:
        None
    """
    # read csv as dataframe then drop empties
    df = pd.read_csv(path)
    df_cleaned = df.dropna()
    # convert dataframe to a dictionary and save as pickle file
    table_data = df_cleaned.to_dict(orient='list')
    with open(PATHS.PKL_FILE_PATH, "wb") as f:
        pickle.dump(table_data, f)

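# Example (hypothetical CSV): a file with header "name,age" becomes {"name": [...], "age": [...]}
# after dropna(), and is pickled to PATHS.PKL_FILE_PATH for load_rows() to read back.
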
def change_insert_mode(choice):
    """
    Drops the table if the user elects to upload a new table; does nothing if there is no
    table to drop or the user chooses to upload to the existing table.
    Args:
        choice (Any): The name of the radio button the user has selected.
    Returns:
        None
    """
    table_status = table_check()
    if choice == "Upload New" and "NoSuchTableError" not in table_status:
        sql_engine(f"DROP TABLE {PATHS.TABLE_NAME};")
    else:
        pass


with gr.Blocks() as demo:
    with gr.Tab("Table Setup"):
        insert_mode = gr.Radio(["Upload New", "Upload to Existing"], label="Insertion Mode",
                               info="Warning: selecting Upload New will immediately drop the existing table; leaving it unselected will add to the existing table.")
        insert_mode.input(fn=change_insert_mode, inputs=insert_mode, outputs=None)
        gr.Markdown("Next upload the csv:")
        gr.Interface(
            fn=process_file,
            inputs=[
                "file",
            ],
            outputs=None,
            flagging_mode="never"
        )
        column_type = gr.Textbox(label="Enter column data types (String, Integer, Float) as a comma separated list:")
        column_type_message = gr.Textbox(label="Feedback:")
        col_type_button = gr.Button("Submit")
        col_type_button.click(update_table, inputs=column_type, outputs=[column_type_message])
    with gr.Tab("Text2SQL Agent"):
        chatbot = gr.Chatbot(type="messages", placeholder="<strong>Ask agent to perform a query.</strong>")
        chatbot.like(vote, None, None)
        gr.ChatInterface(fn=run_prompt, type="messages", chatbot=chatbot)

if __name__ == "__main__":
    # initialize agent once
    agent = agent_setup()

    demo.launch(debug=True)