Mustehson commited on
Commit
1d1ec23
Β·
1 Parent(s): aad99a8

Refactored & Langsmith Prompt

Browse files
Files changed (2) hide show
  1. app.py +43 -96
  2. requirements.txt +2 -1
app.py CHANGED
@@ -5,16 +5,19 @@ import matplotlib.pyplot as plt
5
  from transformers import HfEngine, ReactCodeAgent
6
  from transformers.agents import Tool
7
  from langsmith import traceable
 
 
 
8
  # Height of the Tabs Text Area
9
  TAB_LINES = 8
10
- # Load Token
11
- md_token = os.getenv('MD_TOKEN')
12
- os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
13
 
14
- print('Connecting to DB...')
15
- # Connect to DB
 
16
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
 
17
 
 
18
  models = ["Qwen/Qwen2.5-72B-Instruct","meta-llama/Meta-Llama-3-70B-Instruct",
19
  "meta-llama/Llama-3.1-70B-Instruct"]
20
 
@@ -31,7 +34,14 @@ for model in models:
31
 
32
  if not model_loaded:
33
  gr.Warning(f"❌ None of the model form {models} are available. {e}")
 
 
 
 
 
 
34
 
 
35
  def get_schemas():
36
  schemas = conn.execute("""
37
  SELECT DISTINCT schema_name
@@ -65,92 +75,6 @@ def get_table_schema(table):
65
  return ddl_create, full_path
66
 
67
 
68
- def get_visualization(question, tool, schema, table_name):
69
- agent = ReactCodeAgent(tools=[tool], llm_engine=llm_engine, add_base_tools=True,
70
- additional_authorized_imports=['matplotlib.pyplot',
71
- 'pandas', 'plotly.express',
72
- 'seaborn'], max_iterations=10)
73
- results = agent.run(
74
-
75
- task= f'''
76
-
77
- Here are the steps you should follow while writing code for Visualization:
78
- 1. You have access to the database with the `sql_engine` tool, which allows you to run DuckDB SQL queries and return results as a df.
79
- 2. Query the database using `sql_engine`, print the first 5 rows to inspect the data.
80
- 3. Select the most appropriate chart type for the data:
81
- - Use bar charts for categorical comparisons, line charts for trends over time, scatter plots for relationships between variables, pie charts for proportions, histograms for distribution, and box plots for data spread and outliers.
82
- 4. Analyze the data and choose the best visualization type to answer the query.
83
- 5. Always include a plot in your answer.
84
- 6. Use Seaborn for the plots.
85
- 7. In the end, return a dictionary containing the final figure (`fig` key), the generated SQL (`sql` key), and the data as a dataframe (`data` key) using the `final_answer` tool, e.g. `final_answer(answer={{"fig": 'fig.png', "sql": sql, "data": data}})`.
86
-
87
- Example:
88
-
89
- ```python
90
- # Input query
91
- query_description = 'Average tip amount based on the ride time length in minutes.'
92
-
93
- # SQL Query to get ride time length and average tip amount
94
- query = """
95
- SELECT
96
- EXTRACT(EPOCH FROM (tpep_dropoff_datetime - tpep_pickup_datetime)) / 60 AS ride_time_length,
97
- AVG(tip_amount) AS avg_tip_amount
98
- FROM
99
- sample_data.nyc.taxi
100
- GROUP BY
101
- EXTRACT(EPOCH FROM (tpep_dropoff_datetime - tpep_pickup_datetime)) / 60
102
- """
103
-
104
- # Execute the query using the sql_engine tool
105
- df = sql_engine(query=query)
106
-
107
- # Print the result to observe the data
108
- print(df)
109
-
110
- # Create a line plot using seaborn
111
- import seaborn as sns
112
- import matplotlib.pyplot as plt
113
-
114
- plt.figure(figsize=(10,6))
115
- sns.lineplot(x="ride_time_length", y="avg_tip_amount", data=df)
116
-
117
- # Set the title and labels
118
- plt.title("Average Tip Amount vs Ride Time Length")
119
- plt.xlabel("Ride Time Length (minutes)")
120
- plt.ylabel("Average Tip Amount")
121
-
122
- # Print the plot to observe the results
123
- print("Plot created")
124
-
125
- # Since we are required to return a fig, sql, and data, let's store the plot in a variable
126
- fig = plt.gcf()
127
-
128
- # Store the query in a variable
129
- sql = query
130
-
131
- # Store the dataframe in a variable
132
- data = df
133
-
134
- # Return the final answer
135
- final_answer(answer={{"fig": fig, "sql": sql, "data": data}})
136
- ```
137
-
138
-
139
- Here is the query you should generate a plot for: '{question}'.
140
- Here is the schema: '{schema}' and here is the table name: '{table_name}
141
-
142
-
143
- '''
144
- )
145
-
146
- return results
147
-
148
-
149
- @traceable()
150
- def query_response(input_prompt, generated_sql):
151
- return generated_sql
152
-
153
-
154
  class SQLExecutorTool(Tool):
155
  name = "sql_engine"
156
  inputs = {
@@ -168,6 +92,27 @@ class SQLExecutorTool(Tool):
168
 
169
  tool = SQLExecutorTool()
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def main(table, text_query):
172
  # Empty Fig
173
  fig, ax = plt.subplots()
@@ -175,14 +120,12 @@ def main(table, text_query):
175
 
176
  schema, table_name = get_table_schema(table)
177
 
178
-
179
  try:
180
- output = get_visualization(question=text_query, tool=tool, schema=schema, table_name=table_name)
181
  fig = output.get('fig', None)
182
  generated_sql = output.get('sql', None)
183
  data = output.get('data', None)
184
- input_prompt = text_query + '\n' + table_name + '\n' + schema
185
- _ = query_response(input_prompt, generated_sql)
186
  except Exception as e:
187
  gr.Warning(f"❌ Unable to generate the visualization. {e}")
188
 
@@ -246,4 +189,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
246
  generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot, generated_sql, data])
247
 
248
  if __name__ == "__main__":
249
- demo.launch(debug=True)
 
 
 
 
 
5
  from transformers import HfEngine, ReactCodeAgent
6
  from transformers.agents import Tool
7
  from langsmith import traceable
8
+ from langchain import hub
9
+
10
+
11
  # Height of the Tabs Text Area
12
  TAB_LINES = 8
 
 
 
13
 
14
+
15
+ #----------CONNECT TO DATABASE----------
16
+ md_token = os.getenv('MD_TOKEN')
17
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
18
+ #---------------------------------------
19
 
20
+ #-------LOAD HUGGINGFACE MODEL-------
21
  models = ["Qwen/Qwen2.5-72B-Instruct","meta-llama/Meta-Llama-3-70B-Instruct",
22
  "meta-llama/Llama-3.1-70B-Instruct"]
23
 
 
34
 
35
  if not model_loaded:
36
  gr.Warning(f"❌ None of the model form {models} are available. {e}")
37
+ #---------------------------------------
38
+
39
+ #-----LOAD PROMPT FROM LANCHAIN HUB-----
40
+ prompt = hub.pull("viz-prompt")
41
+ #-------------------------------------
42
+
43
 
44
+ #--------------ALL UTILS----------------
45
  def get_schemas():
46
  schemas = conn.execute("""
47
  SELECT DISTINCT schema_name
 
75
  return ddl_create, full_path
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  class SQLExecutorTool(Tool):
79
  name = "sql_engine"
80
  inputs = {
 
92
 
93
  tool = SQLExecutorTool()
94
 
95
+ def process_outputs(output) :
96
+ if 'data' in output:
97
+ output['data'] = "<DataFrame is hidden>"
98
+ if 'fig' in output:
99
+ output['fig'] = "<Figure is hidden>"
100
+ return output
101
+
102
+ @traceable(process_outputs=process_outputs)
103
+ def get_visualization(question, schema, table_name):
104
+ agent = ReactCodeAgent(tools=[tool], llm_engine=llm_engine, add_base_tools=True,
105
+ additional_authorized_imports=['matplotlib.pyplot',
106
+ 'pandas', 'plotly.express',
107
+ 'seaborn'], max_iterations=10)
108
+ results = agent.run(
109
+ task= prompt.format(question=question, schema=schema, table_name=table_name)
110
+ )
111
+
112
+ return results
113
+ #---------------------------------------
114
+
115
+
116
  def main(table, text_query):
117
  # Empty Fig
118
  fig, ax = plt.subplots()
 
120
 
121
  schema, table_name = get_table_schema(table)
122
 
 
123
  try:
124
+ output = get_visualization(question=text_query, schema=schema, table_name=table_name)
125
  fig = output.get('fig', None)
126
  generated_sql = output.get('sql', None)
127
  data = output.get('data', None)
128
+
 
129
  except Exception as e:
130
  gr.Warning(f"❌ Unable to generate the visualization. {e}")
131
 
 
189
  generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot, generated_sql, data])
190
 
191
  if __name__ == "__main__":
192
+ demo.launch(debug=True)
193
+
194
+
195
+
196
+
requirements.txt CHANGED
@@ -5,4 +5,5 @@ huggingface_hub
5
  accelerate==0.34.2
6
  transformers==4.44.2
7
  duckdb==1.1.1
8
- langsmith==0.1.135
 
 
5
  accelerate==0.34.2
6
  transformers==4.44.2
7
  duckdb==1.1.1
8
+ langsmith==0.1.135
9
+ langchain==0.3.4