Spaces:
Sleeping
Sleeping
import os | |
import pickle | |
import shutil | |
import pandas as pd | |
import gradio as gr | |
from config import PATHS | |
from smolagents import CodeAgent, InferenceClientModel, tool | |
from sqlalchemy import ( | |
create_engine, | |
MetaData, | |
Table, | |
Column, | |
String, | |
Integer, | |
Float, | |
insert, | |
inspect, | |
text, | |
exc, | |
) | |
# Shared SQLAlchemy state for the whole app:
#   metadata_obj collects the Table definitions built at runtime,
#   engine points at the SQLite file agentDB.db (relative to the working directory).
metadata_obj = MetaData()
engine = create_engine("sqlite:///agentDB.db")
def load_rows(pkl_path=None):
    """
    Load the pickled column dict and convert it into a list of row dicts.

    The pickle holds a dict in pandas ``orient="list"`` form: each key is a
    column name and each value is the list of that column's values.

    Args:
        pkl_path (str, optional): Pickle file to read. Defaults to
            ``PATHS.PKL_FILE_PATH`` when omitted.
    Returns:
        col_names (list): The list of column names.
        rows (list): One dict per row, mapping column name -> that row's value.
        num_cols (int): Number of columns.
    Raises:
        ValueError: If the pickled dict contains no columns.
    """
    if pkl_path is None:
        pkl_path = PATHS.PKL_FILE_PATH
    # NOTE(review): pickle.load can execute arbitrary code. This is acceptable only
    # because the file is written by this app (csv_2_dict), never uploaded directly.
    with open(pkl_path, "rb") as f:
        sql_dict = pickle.load(f)
    col_names = list(sql_dict.keys())
    if not col_names:
        raise ValueError("The dictionary is empty.")
    num_cols = len(col_names)
    # Row count comes from the first column; orient="list" columns share one length.
    num_rows = len(sql_dict[col_names[0]])
    rows = [{col: sql_dict[col][i] for col in col_names} for i in range(num_rows)]
    return col_names, rows, num_cols
def insert_rows(rows, table, engine=engine):
    """
    Insert rows into table inside a single transaction.

    All rows are committed together: if any insert fails, the transaction is
    rolled back and none of the rows are stored.

    Args:
        rows (list): List of dicts, each mapping column names to one row's values.
        table (sqlalchemy.Table): Table to insert into.
        engine (sqlalchemy.engine.Engine): Engine to use; defaults to the module-level engine.
    Returns:
        None
    """
    # One transaction for the whole batch instead of one commit per row:
    # far fewer commits, and no partially-inserted table when a later row fails.
    with engine.begin() as connection:
        for row in rows:
            connection.execute(insert(table).values(**row))
def create_dynamic_table(table_name, columns):
    """
    Define an SQL table dynamically on the shared ``metadata_obj``.

    An integer ``id`` primary-key column is always added in front of the user
    columns. ``extend_existing=True`` allows the same table name to be
    redefined on repeated calls.

    Args:
        table_name (String): Name of the table.
        columns (dict): Mapping of column name -> SQLAlchemy type (String, Integer, Float).
    Returns:
        table (sqlalchemy.Table): The table object. It is only defined in metadata;
            ``metadata_obj.create_all(engine)`` must be run to create it in the database.
    """
    # NOTE(review): a CSV column literally named "id" would clash with the primary key below.
    user_columns = [Column(name, type_) for name, type_ in columns.items()]
    return Table(
        table_name,
        metadata_obj,
        Column("id", Integer, primary_key=True),
        *user_columns,
        extend_existing=True,
    )
def update_table(column_type):
    """
    Build the agent table from the uploaded data and insert its rows.

    Reads the pickled column dict via load_rows(), maps each user-supplied type
    name to a SQLAlchemy type, defines/creates the table and inserts every row.

    Args:
        column_type (String): Comma separated column data types, one per column,
            each one of "String", "Integer" or "Float" (surrounding spaces ignored).
    Returns:
        (String): Success message when no errors, otherwise a message describing the failure.
    """
    # Accepted type names -> SQLAlchemy column types.
    type_map = {"String": String, "Integer": Integer, "Float": Float}

    try:
        col_names, rows, _ = load_rows()
    except FileNotFoundError as e:
        # No csv has been uploaded/converted yet.
        return f"{e}. Upload a csv before submitting column data types."
    except ValueError as e:
        return f"{e}"

    type_names = [name.strip() for name in column_type.split(",")]
    if len(type_names) != len(col_names):
        return f"Number of data types ({len(type_names)}) does not match number of columns ({len(col_names)})."
    if any(name not in type_map for name in type_names):
        return "A data type you entered was invalid."

    # Map column name to data type by position.
    columns = {col_name: type_map[type_name] for col_name, type_name in zip(col_names, type_names)}
    dynamic_table = create_dynamic_table(PATHS.TABLE_NAME, columns)
    metadata_obj.create_all(engine)
    try:
        insert_rows(rows, dynamic_table)
    except exc.CompileError as e:
        return f"{e}."
    except exc.OperationalError as e:
        return f"{e}. agentDB has already had it's schema defined."
    return "Row insertion succesful"
def table_description():
    """
    Build a plain-text summary of the agent table's columns for the agent prompt.

    Args:
        None
    Returns:
        (String): "Columns:" followed by one " - name: type" line per column, or a
            "NoSuchTableError: ..." message when the table cannot be found.
    """
    try:
        column_specs = inspect(engine).get_columns(PATHS.TABLE_NAME)
    except exc.NoSuchTableError as err:
        return f"NoSuchTableError: {err}. The referenced table does not exist."
    lines = [f" - {spec['name']}: {spec['type']}" for spec in column_specs]
    return "Columns:\n" + "\n".join(lines)
def table_check() -> str:
    """
    Report whether the agent table exists in the database.

    Callers detect a missing table by searching the returned message for the
    substring "NoSuchTableError", so that prefix is part of the contract.

    Args:
        None
    Returns:
        (String): "Table '<name>' exists." or a "NoSuchTableError: ..." message.
    """
    if inspect(engine).has_table(PATHS.TABLE_NAME):
        return f"Table '{PATHS.TABLE_NAME}' exists."
    return f"NoSuchTableError: The referenced table '{PATHS.TABLE_NAME}' does not exist."
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The Table is named agent_table.

    Args:
        query: The query to be performed on the table. This should always be correct SQL.
    """
    # The docstring above doubles as the agent-facing tool description, so notes live here.
    # NOTE(review): executes arbitrary SQL (including DROP/DELETE) supplied by the agent.
    try:
        # engine.begin() commits on normal exit and rolls back if the statement fails.
        with engine.begin() as con:
            result = con.execute(text(query))
            if not result.returns_rows:
                # DDL / DML without RETURNING: statement ran, but there is nothing to list.
                return "No rows found, include the `RETURNING` keyword to ensure the result object always returns rows."
            return "".join(f"{row}\n" for row in result)
    except exc.SQLAlchemyError as e:
        return f"{e}. Include the `RETURNING` keyword to ensure the result object always returns rows."
def agent_setup():
    """
    Initialize the inference client, as well as the sql agent.

    The Nebius API key is read from the NEBIUS_API_KEY environment variable.

    Args:
        None
    Returns:
        sql_agent (CodeAgent): The agent that will be used for inference.
    """
    nebius_api_key = os.environ.get("NEBIUS_API_KEY")
    sql_model = InferenceClientModel(
        api_key=nebius_api_key,
        model_id="Qwen/Qwen3-235B-A22B",  # smaller alternative: Qwen/Qwen3-4B
        provider="nebius",
    )
    # sql_engine is defined as a plain function (it is also called directly elsewhere),
    # so wrap it with smolagents' `tool` to hand the agent a Tool built from its
    # signature and docstring.
    sql_agent = CodeAgent(
        tools=[tool(sql_engine)],
        model=sql_model,
        max_steps=5,
    )
    return sql_agent
def run_prompt(prompt, history):
    """
    Answer a chat message by running the SQL agent against the agent table.

    Args:
        prompt (String): The user's query to be fed to the agent.
        history (Any): Chat history supplied by gr.ChatInterface; not used here.
    Returns:
        The agent's answer as returned by ``agent.run``, or an error string if
        the table does not exist.
    """
    global agent
    table_status = table_check()
    if "NoSuchTableError" in table_status:
        return table_status + " Check the table has the expected name and it is consistent."
    # `agent` is normally created under `if __name__ == "__main__"`; create it lazily
    # so chat still works when this module is imported instead of run as a script.
    if globals().get("agent") is None:
        agent = agent_setup()
    table_descrip = table_description()
    return agent.run(prompt + f". Always wrap the result in relevant context and enforce the results object returning rows. Table description is as follows:{table_descrip}")
def vote(data: gr.LikeData):
    """
    Print the user's thumbs-up / thumbs-down feedback on an agent response.

    Args:
        data (LikeData): Payload of the chatbot's .like() event.
    Returns:
        None
    """
    verdict = "upvoted" if data.liked else "downvoted"
    print("You " + verdict + " this response: " + data.value["value"])
def process_file(fileobj):
    """
    Copy the uploaded csv into the temporary folder and convert it via csv_2_dict.

    Args:
        fileobj (Any): The upload from the Gradio "file" input: either a path
            string or an object exposing the file path as ``.name``.
    Returns:
        None (the result of csv_2_dict, which writes the pkl file).
    """
    # Resolve the source path once; previously basename() got the raw object
    # while the copy used `.name`, so one of the two could fail.
    src_path = getattr(fileobj, "name", fileobj)
    # Assumes PATHS.TEMP_PATH is a directory; join works with or without a trailing slash.
    csv_path = os.path.join(PATHS.TEMP_PATH, os.path.basename(src_path))
    shutil.copyfile(src_path, csv_path)
    return csv_2_dict(csv_path)
def csv_2_dict(path):
    """
    Convert a csv file into a column dict and pickle it to PATHS.PKL_FILE_PATH.

    Rows containing any missing value are dropped first. The saved dict maps
    each column name to the list of that column's values (pandas orient="list").

    Args:
        path (Any): Path of the csv file to read.
    Returns:
        None
    """
    table_data = pd.read_csv(path).dropna().to_dict(orient='list')
    with open(PATHS.PKL_FILE_PATH, "wb") as pkl_file:
        pickle.dump(table_data, pkl_file)
def change_insert_mode(choice):
    """
    Drop the existing agent table when the user selects "Upload New".

    Does nothing when another option is selected or when there is no table to drop.

    Args:
        choice (Any): The label of the radio button the user has selected.
    Returns:
        None
    """
    if choice != "Upload New":
        return
    if "NoSuchTableError" not in table_check():
        # The status string returned by sql_engine is not shown to the user.
        sql_engine(f"DROP TABLE {PATHS.TABLE_NAME};")
# Gradio UI: a setup tab (insert mode, csv upload, column types) and a chat tab for the agent.
with gr.Blocks() as demo:
    with gr.Tab("Table Setup"):
        insert_mode = gr.Radio(["Upload New", "Upload to Existing"], label="Insertion Mode",
                               info="Warning selecting Upload New will immediately drop existing table, leaving unselected will add to existing table.")
        # Selecting "Upload New" drops the existing table right away (see change_insert_mode).
        insert_mode.input(fn=change_insert_mode, inputs=insert_mode, outputs=None)
        gr.Markdown("Next upload the csv:")
        # Embedded Interface: the uploaded file is copied and pickled by process_file.
        gr.Interface(
            fn=process_file,
            inputs=[
                "file",
            ],
            outputs=None,
            flagging_mode="never"
        )
        column_type = gr.Textbox(label="Enter column data types (String, Integer, Float) as a comma seperated list:")
        column_type_message = gr.Textbox(label="Feedback:")
        col_type_button = gr.Button("Submit")
        # Submit builds the table and inserts the rows; update_table's status string is shown in Feedback.
        col_type_button.click(update_table, inputs=column_type, outputs=[column_type_message, ])
    with gr.Tab("Text2SQL Agent"):
        chatbot = gr.Chatbot(type="messages", placeholder=f"<strong>Ask agent to perform a query.</strong>")
        # Thumbs up/down on a response is reported through vote().
        chatbot.like(vote, None, None)
        # Each chat message is answered by run_prompt().
        gr.ChatInterface(fn=run_prompt, type="messages", chatbot=chatbot)
if __name__ == "__main__":
    # initialize agent
    # `agent` becomes a module-level global that run_prompt() reads to answer chat messages.
    agent = agent_setup()
    demo.launch(debug=True)