Mustehson commited on
Commit
3e27538
·
1 Parent(s): 277bf9b

Added Prompt 2

Browse files
Files changed (2) hide show
  1. app.py +5 -20
  2. prompt.py +62 -0
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  import matplotlib.pyplot as plt
5
  from transformers import HfEngine, ReactCodeAgent
6
  from transformers.agents import Tool
 
7
 
8
  # Height of the Tabs Text Area
9
  TAB_LINES = 8
@@ -64,23 +65,9 @@ def get_visualization(question, tool, schema, table_name):
64
  'pandas', 'plotly.express',
65
  'seaborn'], max_iterations=10)
66
  fig = agent.run(
67
-
68
- instruction='''
69
- THINK STEP BY STEP
70
- Here are the steps you should follow while writing code for Visualization:
71
- 1. Select the most appropriate chart type for data. Use bar charts for categorical comparisons, line charts for trends over time, scatter plots for relationships between variables, pie charts for proportions, histograms for distribution analysis, and box plots for visualizing data spread and outliers.
72
- 2. Ensure clear and appropriate labels, colors, and design elements, keeping visual elements legible and uncluttered.
73
- 3. Follow best practices, avoiding unnecessary visual distractions (chartjunk).
74
- 4. Ensure the code is error-free, with correct fields, transformations, and aesthetics.
75
- 5. Use descriptive and accurate x and y axis labels that reflect the data.
76
- 6. Ensure units of measurement are clearly indicated on axes (e.g., %, $, cm).
77
- 7. Ensure that categorical data is plotted on one axis and numerical data on the other, with appropriate labels that clearly represent the data being visualized.
78
- 8. When plotting categorical data, arrange categories in a meaningful order (e.g., by size, time, or frequency) rather than randomly.
79
- 9. Ensure that the categorical data are plotted on the x-axis, and the frequencies (numerical data) are plotted on the y-axis.
80
- 10. Use seaborn
81
- 11. In the end you have to return a dict which contain final fig as fig key, Generated SQL as sql key, Data as a dataframe with data key using the `final_answer` tool e.g. final_answer(answer={"fig": fig, "sql": sql, "data": data})''',
82
-
83
- task= f'{question}',
84
  schema= f'{schema}',
85
  table_name= f'{table_name}',
86
  )
@@ -96,6 +83,7 @@ class SQLExecutorTool(Tool):
96
  "description": f"The query to perform. This should be correct DuckDB SQL.",
97
  }
98
  }
 
99
  output_type = "pandas.core.frame.DataFrame"
100
 
101
  def forward(self, query: str) -> str:
@@ -110,8 +98,6 @@ def main(table, text_query):
110
  ax.set_axis_off()
111
 
112
  schema, table_name = get_table_schema(table)
113
- tool.description = f"""Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
114
-
115
 
116
  try:
117
  output = get_visualization(question=text_query, tool=tool, schema=schema, table_name=table_name)
@@ -121,7 +107,6 @@ def main(table, text_query):
121
  except Exception as e:
122
  gr.Warning(f"❌ Unable to generate the visualization. {e}")
123
 
124
-
125
  return fig, generated_sql, data
126
 
127
 
 
4
  import matplotlib.pyplot as plt
5
  from transformers import HfEngine, ReactCodeAgent
6
  from transformers.agents import Tool
7
+ from prompt import PROMPT
8
 
9
  # Height of the Tabs Text Area
10
  TAB_LINES = 8
 
65
  'pandas', 'plotly.express',
66
  'seaborn'], max_iterations=10)
67
  fig = agent.run(
68
+
69
+ task=PROMPT,
70
+ query_description= f'{question}',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  schema= f'{schema}',
72
  table_name= f'{table_name}',
73
  )
 
83
  "description": f"The query to perform. This should be correct DuckDB SQL.",
84
  }
85
  }
86
+ description = """Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
87
  output_type = "pandas.core.frame.DataFrame"
88
 
89
  def forward(self, query: str) -> str:
 
98
  ax.set_axis_off()
99
 
100
  schema, table_name = get_table_schema(table)
 
 
101
 
102
  try:
103
  output = get_visualization(question=text_query, tool=tool, schema=schema, table_name=table_name)
 
107
  except Exception as e:
108
  gr.Warning(f"❌ Unable to generate the visualization. {e}")
109
 
 
110
  return fig, generated_sql, data
111
 
112
 
prompt.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROMPT = '''
2
+ Here are the steps you should follow while writing code for Visualization:
3
+ 1. You have access to the database with the `sql_engine` tool, which allows you to run DuckDB SQL queries and return results as a df.
4
+ 2. Query the database using `sql_engine`, print the first 5 rows to inspect the data.
5
+ 3. Select the most appropriate chart type for the data:
6
+ - Use bar charts for categorical comparisons, line charts for trends over time, scatter plots for relationships between variables, pie charts for proportions, histograms for distribution, and box plots for data spread and outliers.
7
+ 4. Analyze the data and choose the best visualization type to answer the query.
8
+ 5. Always include a plot in your answer.
9
+ 6. Use Seaborn for the plots.
10
+ 7. In the end, return a dictionary containing the final figure (`fig` key), the generated SQL (`sql` key), and the data as a dataframe (`data` key) using the `final_answer` tool, e.g. `final_answer(answer={{"fig": 'fig.png', "sql": sql, "data": data}})`.
11
+
12
+ Example:
13
+
14
+ ```python
15
+ # Input query
16
+ query_description = 'Average tip amount based on the ride time length in minutes.'
17
+
18
+ # SQL Query to get ride time length and average tip amount
19
+ query = """
20
+ SELECT
21
+ EXTRACT(EPOCH FROM (tpep_dropoff_datetime - tpep_pickup_datetime)) / 60 AS ride_time_length,
22
+ AVG(tip_amount) AS avg_tip_amount
23
+ FROM
24
+ sample_data.nyc.taxi
25
+ GROUP BY
26
+ EXTRACT(EPOCH FROM (tpep_dropoff_datetime - tpep_pickup_datetime)) / 60
27
+ """
28
+
29
+ # Execute the query using the sql_engine tool
30
+ df = sql_engine(query=query)
31
+
32
+ # Print the result to observe the data
33
+ print(df)
34
+
35
+ # Create a line plot using seaborn
36
+ import seaborn as sns
37
+ import matplotlib.pyplot as plt
38
+
39
+ plt.figure(figsize=(10,6))
40
+ sns.lineplot(x="ride_time_length", y="avg_tip_amount", data=df)
41
+
42
+ # Set the title and labels
43
+ plt.title("Average Tip Amount vs Ride Time Length")
44
+ plt.xlabel("Ride Time Length (minutes)")
45
+ plt.ylabel("Average Tip Amount")
46
+
47
+ # Print the plot to observe the results
48
+ print("Plot created")
49
+
50
+ # Since we are required to return a fig, sql, and data, let's store the plot in a variable
51
+ fig = plt.gcf()
52
+
53
+ # Store the query in a variable
54
+ sql = query
55
+
56
+ # Store the dataframe in a variable
57
+ data = df
58
+
59
+ # Return the final answer
60
+ final_answer(answer={"fig": fig, "sql": sql, "data": data})
61
+ ```
62
+ '''