Spaces:
Running
Running
import os | |
import duckdb | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from transformers import HfEngine, ReactCodeAgent | |
from transformers.agents import Tool | |
from langsmith import traceable | |
from langchain import hub | |
# Number of text lines given to the read-only "Generated SQL" textbox (lines=TAB_LINES).
TAB_LINES = 8

#----------CONNECT TO DATABASE----------
# The MotherDuck token is taken from the environment rather than hard-coded.
md_token = os.getenv('MD_TOKEN')
if not md_token:
    # Without this guard the text "None" (or an empty token) would be
    # interpolated into the connection string below.
    raise RuntimeError("The MD_TOKEN environment variable must be set to a MotherDuck token.")
# Shared module-level connection, opened read-only; every helper below and the
# agent's SQL tool run their queries through it.
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
#---------------------------------------
#-------LOAD HUGGINGFACE MODEL-------
# Candidate models, tried in order; the first one whose endpoint info can be
# fetched is kept as the engine used by the agent.
models = ["Qwen/Qwen2.5-72B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct",
          "meta-llama/Llama-3.1-70B-Instruct"]
model_loaded = False
# Pre-bind the names so they exist even when every candidate fails.
llm_engine = None
# The `except ... as e` name is deleted when its clause ends, so the last
# failure is kept here for the summary warning after the loop.
last_error = None
for model in models:
    try:
        llm_engine = HfEngine(model=model)
        # Probe the endpoint: raises if the model is not currently served.
        info = llm_engine.client.get_endpoint_info()
        model_loaded = True
        break
    except Exception as e:
        print(f"Error for model {model}: {e}")
        last_error = e
if not model_loaded:
    gr.Warning(f"β None of the model from {models} are available. {last_error}")
#---------------------------------------
#-----LOAD PROMPT FROM LANGCHAIN HUB-----
# Pull the prompt named "viz-prompt" from the hub once at import time.
# get_visualization() fills it in with
# prompt.format(question=..., schema=..., table_name=...).
prompt = hub.pull("viz-prompt")
#-------------------------------------
#--------------ALL UTILS---------------- | |
def get_schemas():
    """Return the names of the schemas visible on the shared connection.

    The query itself filters out the built-in ``information_schema`` and
    ``pg_catalog`` schemas.

    Returns:
        A list of schema-name strings, one per distinct schema.
    """
    query = (
        "\n"
        "SELECT DISTINCT schema_name\n"
        "FROM information_schema.schemata\n"
        "WHERE schema_name NOT IN ('information_schema', 'pg_catalog')\n"
    )
    rows = conn.execute(query).fetchall()
    # Each row is a 1-tuple because exactly one column is selected.
    return [name for (name,) in rows]
# Get Tables
def get_tables(schema_name):
    """Return the names of the tables that live in ``schema_name``.

    Args:
        schema_name: Schema to look up, as chosen in the schema dropdown.

    Returns:
        A list of table-name strings (empty when the schema has no tables
        or does not exist).
    """
    # The schema name comes from the UI, so it is bound as a parameter rather
    # than spliced into the SQL text.
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_tables(schema_name):
    """Build the Gradio update that refreshes the table dropdown's choices.

    Args:
        schema_name: Schema currently selected in the schema dropdown.

    Returns:
        The result of ``gr.update(choices=...)`` carrying the tables of
        ``schema_name``.
    """
    return gr.update(choices=get_tables(schema_name))
# Get Schema
def get_table_schema(table):
    """Return the CREATE statement and fully-qualified name of ``table``.

    The table is looked up by name in ``duckdb_tables()``; its DDL is then
    rewritten so that the table is referred to as ``database.schema.table``.

    Args:
        table: Unqualified table name, as chosen in the tables dropdown.

    Returns:
        A ``(ddl_create, full_path)`` tuple: the rewritten DDL text and the
        ``database.schema.table`` path.

    Raises:
        IndexError: If no table with that name exists.
    """
    # The table name comes from the UI, so it is bound as a parameter rather
    # than spliced into the SQL text.
    result = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).df()
    if result.empty:
        # Same exception type the positional lookup below would raise, but
        # with a message that names the actual problem.
        raise IndexError(f"No table named {table!r} was found.")

    # NOTE(review): when several schemas hold a table with this name, only the
    # first row returned is used.
    ddl_create = result.iloc[0, 0]
    parent_database = result.iloc[0, 1]
    schema_name = result.iloc[0, 2]
    full_path = f"{parent_database}.{schema_name}.{table}"

    # The path to replace is "schema.table" unless the schema is "main", in
    # which case only the bare table name is searched for.
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"
    # NOTE(review): str.replace substitutes every occurrence of old_path, so a
    # column whose name contains the table name is rewritten as well.
    ddl_create = ddl_create.replace(old_path, full_path)
    return ddl_create, full_path
class SQLExecutorTool(Tool):
    """Agent tool that runs a SQL query on the shared DuckDB connection.

    The class attributes below are the tool metadata (name, input schema,
    description and output type) declared for the agent framework.
    """

    name = "sql_engine"
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be correct DuckDB SQL.",
        }
    }
    description = """Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
    output_type = "pandas.core.frame.DataFrame"

    def forward(self, query: str) -> "pandas.DataFrame":
        """Execute ``query`` and return the result as a pandas DataFrame.

        Args:
            query: SQL text to run; it is passed to the connection unchanged.

        Returns:
            The query result converted with ``.df()``.
        """
        # The query text is written by the LLM agent and is executed as-is on
        # the module-level connection, which is opened with read_only=True.
        output_df = conn.sql(query).df()
        return output_df


# Single shared tool instance handed to the agent in get_visualization().
tool = SQLExecutorTool()
def process_outputs(output):
    """Reduce ``output`` to just its ``'sql'`` and ``'code'`` entries.

    Args:
        output: Mapping that may contain ``'sql'`` and/or ``'code'`` keys.

    Returns:
        A new dict with exactly the keys ``'sql'`` and ``'code'``; a key that
        is missing from ``output`` maps to ``None``.
    """
    wanted_keys = ('sql', 'code')
    return {key: output.get(key) for key in wanted_keys}
def get_visualization(question, schema, table_name):
    """Run a code agent on the formatted prompt and return its result.

    A fresh ``ReactCodeAgent`` is created per call with the SQL tool, the
    framework's base tools, a fixed list of extra importable modules and a
    limit of 10 iterations.

    Args:
        question: The user's natural-language request.
        schema: DDL text describing the table.
        table_name: Fully-qualified name of the table to query.

    Returns:
        Whatever ``agent.run`` returns for the formatted prompt.
    """
    # Modules the agent-generated code may import in addition to the defaults.
    extra_imports = ['matplotlib.pyplot',
                     'pandas', 'plotly.express',
                     'seaborn']
    agent = ReactCodeAgent(
        tools=[tool],
        llm_engine=llm_engine,
        add_base_tools=True,
        additional_authorized_imports=extra_imports,
        max_iterations=10,
    )
    task = prompt.format(question=question, schema=schema, table_name=table_name)
    return agent.run(task=task)
#--------------------------------------- | |
def main(table, text_query):
    """Handle a "Run Query" click: build the plot, SQL and data for a request.

    Args:
        table: Table name selected in the tables dropdown.
        text_query: The user's natural-language request.

    Returns:
        A ``(fig, generated_sql, data)`` tuple for the plot, SQL and data
        outputs. On failure a Gradio warning is shown and the tuple is
        ``(blank figure, None, None)``.
    """
    # Blank placeholder figure, returned when no visualization is produced.
    placeholder, ax = plt.subplots()
    ax.set_axis_off()

    # Bind every returned name up front so the failure path cannot hit an
    # UnboundLocalError at the return statement.
    fig = placeholder
    generated_sql = None
    data = None
    try:
        # The schema lookup is inside the try so that a missing or unselected
        # table is reported as a warning as well.
        schema, table_name = get_table_schema(table)
        output = get_visualization(question=text_query, schema=schema, table_name=table_name)
        fig = output.get('fig', None)
        generated_sql = output.get('sql', None)
        data = output.get('data', None)
    except Exception as e:
        gr.Warning(f"β Unable to generate the visualization. {e}")

    if fig is not placeholder:
        # The placeholder is not being returned; close it so pyplot does not
        # keep one more open figure per request.
        plt.close(placeholder)
    return fig, generated_sql, data
# Extra CSS passed to gr.Blocks(css=...): sets the page background colour,
# sizes and centres elements with the "logo" class, and overrides the button
# background colour (normal and hover states).
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# ---- UI layout and event wiring ----
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm it against the intended layout.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    # Header: logo image followed by a centred title and subtitle.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>DataViz Agent</strong>
<br>
<span style='font-size: 20px;'>Visualize SQL queries based on a given text for the dataset.</span>
</div>
""")
    # Input row: schema/table pickers on the left, free-text request on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Schema choices are read from the database once, while the UI is built.
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            # Starts with no choices; filled by update_tables when a schema is selected.
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=3, label="Text Query", placeholder="Enter your text query here...")
    # Button row: an empty scale-7 column precedes the scale-1 column holding the button.
    with gr.Row():
        with gr.Column(scale=7):
            pass
        with gr.Column(scale=1):
            generate_query_button = gr.Button("Run Query", variant="primary")
    # Output tabs: one component per element of the tuple returned by main().
    with gr.Tabs():
        with gr.Tab("Plot"):
            result_plot = gr.Plot()
        with gr.Tab("SQL"):
            generated_sql = gr.Textbox(lines=TAB_LINES, label="Generated SQL", value="", interactive=False,
                                       autoscroll=False)
        with gr.Tab("Data"):
            data = gr.Dataframe(label="Data", interactive=False)
    # Selecting a schema reloads the list of tables offered in the tables dropdown.
    schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
    # Clicking the button runs main(table, text_query) and routes its three
    # return values to the plot, SQL and data components, in that order.
    generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot, generated_sql, data])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)