Spaces:
Sleeping
Sleeping
Mustehson
commited on
Commit
·
3e27538
1
Parent(s):
277bf9b
Added Prompt 2
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
from transformers import HfEngine, ReactCodeAgent
|
6 |
from transformers.agents import Tool
|
|
|
7 |
|
8 |
# Height of the Tabs Text Area
|
9 |
TAB_LINES = 8
|
@@ -64,23 +65,9 @@ def get_visualization(question, tool, schema, table_name):
|
|
64 |
'pandas', 'plotly.express',
|
65 |
'seaborn'], max_iterations=10)
|
66 |
fig = agent.run(
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
Here are the steps you should follow while writing code for Visualization:
|
71 |
-
1. Select the most appropriate chart type for data. Use bar charts for categorical comparisons, line charts for trends over time, scatter plots for relationships between variables, pie charts for proportions, histograms for distribution analysis, and box plots for visualizing data spread and outliers.
|
72 |
-
2. Ensure clear and appropriate labels, colors, and design elements, keeping visual elements legible and uncluttered.
|
73 |
-
3. Follow best practices, avoiding unnecessary visual distractions (chartjunk).
|
74 |
-
4. Ensure the code is error-free, with correct fields, transformations, and aesthetics.
|
75 |
-
5. Use descriptive and accurate x and y axis labels that reflect the data.
|
76 |
-
6. Ensure units of measurement are clearly indicated on axes (e.g., %, $, cm).
|
77 |
-
7. Ensure that categorical data is plotted on one axis and numerical data on the other, with appropriate labels that clearly represent the data being visualized.
|
78 |
-
8. When plotting categorical data, arrange categories in a meaningful order (e.g., by size, time, or frequency) rather than randomly.
|
79 |
-
9. Ensure that the categorical data are plotted on the x-axis, and the frequencies (numerical data) are plotted on the y-axis.
|
80 |
-
10. Use seaborn
|
81 |
-
11. In the end you have to return a dict which contain final fig as fig key, Generated SQL as sql key, Data as a dataframe with data key using the `final_answer` tool e.g. final_answer(answer={"fig": fig, "sql": sql, "data": data})''',
|
82 |
-
|
83 |
-
task= f'{question}',
|
84 |
schema= f'{schema}',
|
85 |
table_name= f'{table_name}',
|
86 |
)
|
@@ -96,6 +83,7 @@ class SQLExecutorTool(Tool):
|
|
96 |
"description": f"The query to perform. This should be correct DuckDB SQL.",
|
97 |
}
|
98 |
}
|
|
|
99 |
output_type = "pandas.core.frame.DataFrame"
|
100 |
|
101 |
def forward(self, query: str) -> str:
|
@@ -110,8 +98,6 @@ def main(table, text_query):
|
|
110 |
ax.set_axis_off()
|
111 |
|
112 |
schema, table_name = get_table_schema(table)
|
113 |
-
tool.description = f"""Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
|
114 |
-
|
115 |
|
116 |
try:
|
117 |
output = get_visualization(question=text_query, tool=tool, schema=schema, table_name=table_name)
|
@@ -121,7 +107,6 @@ def main(table, text_query):
|
|
121 |
except Exception as e:
|
122 |
gr.Warning(f"❌ Unable to generate the visualization. {e}")
|
123 |
|
124 |
-
|
125 |
return fig, generated_sql, data
|
126 |
|
127 |
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
from transformers import HfEngine, ReactCodeAgent
|
6 |
from transformers.agents import Tool
|
7 |
+
from prompt import PROMPT
|
8 |
|
9 |
# Height of the Tabs Text Area
|
10 |
TAB_LINES = 8
|
|
|
65 |
'pandas', 'plotly.express',
|
66 |
'seaborn'], max_iterations=10)
|
67 |
fig = agent.run(
|
68 |
+
|
69 |
+
task=PROMPT,
|
70 |
+
query_description= f'{question}',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
schema= f'{schema}',
|
72 |
table_name= f'{table_name}',
|
73 |
)
|
|
|
83 |
"description": f"The query to perform. This should be correct DuckDB SQL.",
|
84 |
}
|
85 |
}
|
86 |
+
description = """Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
|
87 |
output_type = "pandas.core.frame.DataFrame"
|
88 |
|
89 |
def forward(self, query: str) -> str:
|
|
|
98 |
ax.set_axis_off()
|
99 |
|
100 |
schema, table_name = get_table_schema(table)
|
|
|
|
|
101 |
|
102 |
try:
|
103 |
output = get_visualization(question=text_query, tool=tool, schema=schema, table_name=table_name)
|
|
|
107 |
except Exception as e:
|
108 |
gr.Warning(f"❌ Unable to generate the visualization. {e}")
|
109 |
|
|
|
110 |
return fig, generated_sql, data
|
111 |
|
112 |
|
prompt.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PROMPT = '''
|
2 |
+
Here are the steps you should follow while writing code for Visualization:
|
3 |
+
1. You have access to the database with the `sql_engine` tool, which allows you to run DuckDB SQL queries and return results as a df.
|
4 |
+
2. Query the database using `sql_engine`, print the first 5 rows to inspect the data.
|
5 |
+
3. Select the most appropriate chart type for the data:
|
6 |
+
- Use bar charts for categorical comparisons, line charts for trends over time, scatter plots for relationships between variables, pie charts for proportions, histograms for distribution, and box plots for data spread and outliers.
|
7 |
+
4. Analyze the data and choose the best visualization type to answer the query.
|
8 |
+
5. Always include a plot in your answer.
|
9 |
+
6. Use Seaborn for the plots.
|
10 |
+
7. In the end, return a dictionary containing the final figure (`fig` key), the generated SQL (`sql` key), and the data as a dataframe (`data` key) using the `final_answer` tool, e.g. `final_answer(answer={{"fig": 'fig.png', "sql": sql, "data": data}})`.
|
11 |
+
|
12 |
+
Example:
|
13 |
+
|
14 |
+
```python
|
15 |
+
# Input query
|
16 |
+
query_description = 'Average tip amount based on the ride time length in minutes.'
|
17 |
+
|
18 |
+
# SQL Query to get ride time length and average tip amount
|
19 |
+
query = """
|
20 |
+
SELECT
|
21 |
+
EXTRACT(EPOCH FROM (tpep_dropoff_datetime - tpep_pickup_datetime)) / 60 AS ride_time_length,
|
22 |
+
AVG(tip_amount) AS avg_tip_amount
|
23 |
+
FROM
|
24 |
+
sample_data.nyc.taxi
|
25 |
+
GROUP BY
|
26 |
+
EXTRACT(EPOCH FROM (tpep_dropoff_datetime - tpep_pickup_datetime)) / 60
|
27 |
+
"""
|
28 |
+
|
29 |
+
# Execute the query using the sql_engine tool
|
30 |
+
df = sql_engine(query=query)
|
31 |
+
|
32 |
+
# Print the result to observe the data
|
33 |
+
print(df)
|
34 |
+
|
35 |
+
# Create a line plot using seaborn
|
36 |
+
import seaborn as sns
|
37 |
+
import matplotlib.pyplot as plt
|
38 |
+
|
39 |
+
plt.figure(figsize=(10,6))
|
40 |
+
sns.lineplot(x="ride_time_length", y="avg_tip_amount", data=df)
|
41 |
+
|
42 |
+
# Set the title and labels
|
43 |
+
plt.title("Average Tip Amount vs Ride Time Length")
|
44 |
+
plt.xlabel("Ride Time Length (minutes)")
|
45 |
+
plt.ylabel("Average Tip Amount")
|
46 |
+
|
47 |
+
# Print the plot to observe the results
|
48 |
+
print("Plot created")
|
49 |
+
|
50 |
+
# Since we are required to return a fig, sql, and data, let's store the plot in a variable
|
51 |
+
fig = plt.gcf()
|
52 |
+
|
53 |
+
# Store the query in a variable
|
54 |
+
sql = query
|
55 |
+
|
56 |
+
# Store the dataframe in a variable
|
57 |
+
data = df
|
58 |
+
|
59 |
+
# Return the final answer
|
60 |
+
final_answer(answer={"fig": fig, "sql": sql, "data": data})
|
61 |
+
```
|
62 |
+
'''
|