Mustehson committed on
Commit
9a42f0f
·
1 Parent(s): fb83515

Agent Execution

Browse files
Files changed (4) hide show
  1. app.py +49 -260
  2. plot_utils.py +0 -85
  3. requirements.txt +1 -6
  4. visualization_prompt.py +0 -154
app.py CHANGED
@@ -1,16 +1,8 @@
1
  import os
2
- import json
3
- import torch
4
  import duckdb
5
- import spaces
6
- import pandas as pd
7
  import gradio as gr
8
- import matplotlib.pyplot as plt
9
- from visualization_prompt import graph_instructions
10
- from langchain_core.prompts import ChatPromptTemplate
11
- from langchain_huggingface.llms import HuggingFacePipeline
12
- from plot_utils import plot_bar_chart, plot_horizontal_bar_chart, plot_line_graph, plot_scatter, plot_pie
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
14
 
15
  # Height of the Tabs Text Area
16
  TAB_LINES = 8
@@ -22,32 +14,7 @@ print('Connecting to DB...')
22
  # Connect to DB
23
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
24
 
25
- if torch.cuda.is_available():
26
- device = torch.device("cuda")
27
- print(f"Using GPU: {torch.cuda.get_device_name(device)}")
28
- else:
29
- device = torch.device("cpu")
30
- print("Using CPU")
31
-
32
- print('Loading Model...')
33
-
34
-
35
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
36
-
37
- quantization_config = BitsAndBytesConfig(
38
- load_in_4bit=True,
39
- bnb_4bit_compute_dtype=torch.bfloat16,
40
- bnb_4bit_use_double_quant=True,
41
- bnb_4bit_quant_type= "nf4")
42
-
43
- model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", quantization_config=quantization_config,
44
- device_map="auto", torch_dtype=torch.bfloat16)
45
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
46
- llm = HuggingFacePipeline(pipeline=pipe)
47
-
48
-
49
- print('Model Loaded...')
50
- print(f'Model Device: {model.device}')
51
 
52
  def get_schemas():
53
  schemas = conn.execute("""
@@ -79,225 +46,54 @@ def get_table_schema(table):
79
  else:
80
  old_path = table
81
  ddl_create = ddl_create.replace(old_path, full_path)
82
- return ddl_create
83
-
84
- # Get Prompt
85
- def get_prompt(schema, query_input):
86
- text = f"""
87
- ### Instruction:
88
- Your task is to generate a valid DuckDB SQL query to answer the following question. Select only the columns that are necessary for visualization or that will make visualization easier (e.g., numeric or categorical data). Respond only with the SQL query and do not include anything else.
89
-
90
- ### Input:
91
- Here is the database schema that the SQL query will run on:
92
- {schema}
 
 
 
 
 
93
 
94
- ### Question:
95
- {query_input}
96
-
97
- ### Response (use DuckDB shorthand if possible):
98
- """
99
- return text
100
-
101
- @spaces.GPU(duration=60)
102
- def run_llm(prompt):
103
- result = llm.invoke(prompt)
104
- return result
105
-
106
-
107
- def get_visualization_type(text_query, sql_query, sql_result):
108
- system_message = '''
109
- You are an AI assistant that recommends appropriate data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable type of graph or chart to visualize the data. If no visualization is appropriate, indicate that.
110
-
111
- Available chart types and their use cases:
112
- - Bar Graphs: Best for comparing categorical data or showing changes over time when categories are discrete and the number of categories is more than 2. Use for questions like "What are the sales figures for each product?" or "How does the population of cities compare? or "What percentage of each city is male?"
113
- - Horizontal Bar Graphs: Best for comparing categorical data or showing changes over time when the number of categories is small or the disparity between categories is large. Use for questions like "Show the revenue of A and B?" or "How does the population of 2 cities compare?" or "How many men and women got promoted?" or "What percentage of men and what percentage of women got promoted?" when the disparity between categories is large.
114
- - Scatter Plots: Useful for identifying relationships or correlations between two numerical variables or plotting distributions of data. Best used when both x axis and y axis are continuous. Use for questions like "Plot a distribution of the fares (where the x axis is the fare and the y axis is the count of people who paid that fare)" or "Is there a relationship between advertising spend and sales?" or "How do height and weight correlate in the dataset? Do not use it for questions that do not have a continuous x axis."
115
- - Pie Charts: Ideal for showing proportions or percentages within a whole. Use for questions like "What is the market share distribution among different companies?" or "What percentage of the total revenue comes from each product?"
116
- - Line Graphs: Best for showing trends and distributionsover time. Best used when both x axis and y axis are continuous. Used for questions like "How have website visits changed over the year?" or "What is the trend in temperature over the past decade?". Do not use it for questions that do not have a continuous x axis or a time based x axis.
117
-
118
- Consider these types of questions when recommending a visualization:
119
- 1. Aggregations and Summarizations (e.g., "What is the average revenue by month?" - Line Graph)
120
- 2. Comparisons (e.g., "Compare the sales figures of Product A and Product B over the last year." - Line or Column Graph)
121
- 3. Plotting Distributions (e.g., "Plot a distribution of the age of users" - Scatter Plot)
122
- 4. Trends Over Time (e.g., "What is the trend in the number of active users over the past year?" - Line Graph)
123
- 5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
124
- 6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)
125
-
126
- Generate a JSON object. The JSON object should have the following structure:
127
-
128
- Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
129
- Reason: [Brief explanation for your recommendation]
130
-
131
- '''
132
- human_message = '''
133
- User question: {question}
134
- SQL query: {sql_query}
135
- Query results: {results}
136
-
137
- Recommend a visualization:
138
- '''
139
-
140
- prompt = ChatPromptTemplate.from_messages([
141
- ("system", system_message),
142
- ("human", human_message),
143
- ])
144
-
145
- final_prompt = prompt.format_prompt(question=text_query,
146
- sql_query=sql_query, results=sql_result)
147
- response = run_llm(final_prompt)
148
- response = response.replace('```', '')
149
- response = response.replace('json', '')
150
- json_data = json.loads(response)
151
- visualization = json_data['Recommended Visualization']
152
- reason = json_data['Reason']
153
- print(visualization, reason)
154
- return visualization, reason
155
-
156
- def format_bar_data(results, question, sql_query):
157
- if isinstance(results, str):
158
- results = eval(results)
159
-
160
- if len(results[0]) == 2:
161
- labels = [str(row[0]) for row in results]
162
- data = [float(row[1]) if row[1] is not None else 0 for row in results]
163
-
164
- prompt = ChatPromptTemplate.from_messages([
165
- ("system", "You are a data labeling expert. Given a question, SQL Query used, and some data, provide a concise and relevant label for the data series."),
166
- ("human", "Question: {question}\n SQL Query: {sql_query} \nData (first few rows): {data}\n\n Provide a concise label name for this y axis. Just give me the name."),
167
- ])
168
- prompt = prompt.format_prompt(question=question, data=str(results[:2]),
169
- sql_query=sql_query)
170
- label = run_llm(prompt)
171
- label = label.replace('**Answer:**', '')
172
- values = [{"data": data, "label": label}]
173
- elif len(results[0]) == 3:
174
- categories = set(row[1] for row in results)
175
- labels = list(categories)
176
- entities = set(row[0] for row in results)
177
- values = []
178
- for entity in entities:
179
- entity_data = [float(row[2]) for row in results if row[0] == entity]
180
- values.append({"data": entity_data, "label": str(entity)})
181
- else:
182
- raise ValueError("Unexpected data format in results")
183
-
184
- formatted_data = {
185
- "labels": labels,
186
- "values": values
187
  }
 
188
 
189
- return formatted_data
190
-
191
- def format_data(text_query, sql_query, sql_result, visualization_type):
192
- instruction = graph_instructions[visualization_type]
193
-
194
- template = ChatPromptTemplate.from_messages([
195
- ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user,Given the question, SQL query, and results, your job is to Understand the question and SQL query. Ensure the result matches the query and can be easily visualized."),
196
- ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
197
- ])
198
-
199
-
200
- prompt = template.format_prompt(question=text_query, sql_query=sql_query,
201
- results=sql_result, instructions=instruction)
202
- formatted_data = run_llm(prompt)
203
- formatted_data = formatted_data.replace('```', '')
204
- formatted_data = formatted_data.replace('json', '')
205
- print(f'Formatted Data {formatted_data}')
206
- return json.loads(formatted_data.replace('.', '').strip())
207
-
208
- def visualize_result(text_query, visualization_type, sql_query,
209
- sql_result):
210
-
211
- if visualization_type == 'bar':
212
- try:
213
- data = format_bar_data(sql_result, text_query, sql_query)
214
- except Exception as e:
215
- data = format_data(text_query=text_query, sql_query=sql_query,
216
- sql_result=sql_result, visualization_type=visualization_type)
217
- return plot_bar_chart(data)
218
-
219
- elif visualization_type == 'horizontal_bar':
220
- data = format_data(text_query=text_query, sql_query=sql_query,
221
- sql_result=sql_result, visualization_type=visualization_type)
222
- return plot_horizontal_bar_chart(data)
223
-
224
- elif visualization_type == 'line':
225
- data = format_data(text_query=text_query, sql_query=sql_query,
226
- sql_result=sql_result, visualization_type=visualization_type)
227
- return plot_line_graph(data)
228
-
229
- elif visualization_type == 'pie':
230
- data = format_data(text_query=text_query, sql_query=sql_query,
231
- sql_result=sql_result, visualization_type=visualization_type)
232
- return plot_pie(data)
233
-
234
- elif visualization_type == 'scatter':
235
- data = format_data(text_query=text_query, sql_query=sql_query,
236
- sql_result=sql_result, visualization_type=visualization_type)
237
- return plot_scatter(data)
238
-
239
- elif visualization_type == 'none':
240
- fig, ax = plt.subplots()
241
- ax.set_visible(False)
242
- return fig
243
-
244
-
245
 
246
  def main(table, text_query):
247
- if table is None:
248
- return ["", "", "", pd.DataFrame([{"error": "❌ Table is None."}])]
249
- fig, ax = plt.subplots()
250
- ax.set_visible(False)
251
-
252
- schema = get_table_schema(table)
253
- prompt = get_prompt(schema, text_query)
254
- try:
255
- generated_sql_query = run_llm(prompt)
256
- print(f'Generated SQL Query: {generated_sql_query}')
257
- except Exception as e:
258
- return generate_output(schema, prompt, '', fig, pd.DataFrame([{"error": f"❌ Unable to generate the SQL query. {e}"}]))
259
-
260
- try:
261
- sql_query_result_raw = conn.sql(generated_sql_query)
262
- sql_query_result = sql_query_result_raw.fetchall()
263
- sql_query_df = sql_query_result_raw.df()
264
-
265
- print(f"SQL Query Result: {sql_query_result}")
266
- except Exception as e:
267
- return generate_output(schema, prompt, generated_sql_query, fig, pd.DataFrame([{"error": f"❌ Unable to execute the SQL query. {e}"}]))
268
-
269
- if len(sql_query_result) >= 100:
270
- gr.Warning(f"⚠️ Data is too large for visualization. Please refine your query.")
271
- return generate_output(schema, prompt, generated_sql_query,
272
- fig, sql_query_df)
273
 
274
  try:
275
- visualization_type, reason = get_visualization_type(text_query=text_query,
276
- sql_query=generated_sql_query, sql_result=sql_query_result)
277
-
278
- print(f"Visualization Type: {visualization_type}")
279
- print(f"Reason: {reason}")
280
-
281
  except Exception as e:
282
- return generate_output(schema, prompt, generated_sql_query,
283
- fig, sql_query_df)
 
284
 
285
- if visualization_type != 'none':
286
- try:
287
- plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
288
- sql_result=sql_query_result, visualization_type=visualization_type)
289
-
290
- except Exception as e:
291
- gr.Warning(f"⚠️ {e}")
292
- return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)
293
- else:
294
- gr.Warning(f"⚠️ {reason}")
295
- return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)
296
-
297
- return generate_output(schema, prompt, generated_sql_query, plot, sql_query_df)
298
-
299
- def generate_output(schema, prompt, generated_sql_query, result_plot, sql_query_df):
300
- return [schema, prompt, generated_sql_query, result_plot, sql_query_df]
301
 
302
  custom_css = """
303
  .gradio-container {
@@ -312,6 +108,7 @@ custom_css = """
312
  background-color: #4a90e2 !important;
313
  }
314
  .gr-button:hover {
 
315
  background-color: #3a7bc8 !important;
316
  }
317
  """
@@ -329,32 +126,24 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
329
 
330
  with gr.Row():
331
 
332
- with gr.Column(scale=1, variant='panel'):
333
  schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
334
  tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
335
 
336
  with gr.Column(scale=2):
337
- query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
338
  with gr.Row():
339
  with gr.Column(scale=7):
340
  pass
341
  with gr.Column(scale=1):
342
  generate_query_button = gr.Button("Run Query", variant="primary")
343
 
344
- with gr.Tabs():
345
- with gr.Tab("Result"):
346
- query_result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
347
- with gr.Tab("Plot"):
348
- result_plot = gr.Plot()
349
- with gr.Tab("SQL Query"):
350
- generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
351
- with gr.Tab("Prompt"):
352
- input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
353
- with gr.Tab("Schema"):
354
- table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
355
 
356
  schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
357
- generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_plot, query_result_output])
358
 
359
  if __name__ == "__main__":
360
  demo.launch(debug=True)
 
1
  import os
 
 
2
  import duckdb
 
 
3
  import gradio as gr
4
+ from transformers import HfEngine, ReactCodeAgent
5
+ from transformers.agents import Tool
 
 
 
 
6
 
7
  # Height of the Tabs Text Area
8
  TAB_LINES = 8
 
14
  # Connect to DB
15
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
16
 
17
+ llm_engine = HfEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def get_schemas():
20
  schemas = conn.execute("""
 
46
  else:
47
  old_path = table
48
  ddl_create = ddl_create.replace(old_path, full_path)
49
+ return ddl_create, full_path
50
+
51
+ def get_visualization(question, tool):
52
+ agent = ReactCodeAgent(tools=[tool], llm_engine=llm_engine, add_base_tools=True,
53
+ additional_authorized_imports=['matplotlib.pyplot',
54
+ 'pandas', 'plotly.express',
55
+ 'seaborn'], max_iterations=20)
56
+ fig = agent.run(
57
+ task=f'''
58
+ Use seaborn. Always
59
+ Question: {question}
60
+ Always use the right colors.
61
+ If the question is about showing n number of rows return empty figure.
62
+ In the end you have to return a final fig using the `final_answer` tool
63
+ ''',
64
+ )
65
 
66
+ return fig
67
+
68
class SQLExecutorTool(Tool):
    """Agent tool that runs a DuckDB SQL query and returns the result as a DataFrame.

    ``description`` is not defined on the class; ``main`` assigns it on the
    shared instance before each agent run so that it carries the schema of
    the selected table.
    """

    name = "sql_engine"
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be correct DuckDB SQL.",
        }
    }
    output_type = "pandas.core.frame.DataFrame"

    def forward(self, query: str):
        """Execute ``query`` on a fresh read-only connection.

        Args:
            query: DuckDB SQL text to run.

        Returns:
            The query result materialised with ``.df()`` (a pandas DataFrame).
        """
        # Query through the connection opened here. The previous code opened
        # `con` but then called the module-level `conn`, so the new connection
        # was created and closed without ever being used.
        with duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True) as con:
            # Materialise inside the `with` block, while `con` is still open.
            return con.sql(query).df()
82
+
83
+ tool = SQLExecutorTool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def main(table, text_query):
86
+ schema, _ = get_table_schema(table)
87
+ tool.description = f"""Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result.
88
+ The table schema is as follows: \n{schema}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  try:
91
+ fig = get_visualization(question=text_query, tool=tool)
 
 
 
 
 
92
  except Exception as e:
93
+ gr.Warning(f"❌ Unable to generate the visualization. {e}")
94
+
95
+ return fig
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  custom_css = """
99
  .gradio-container {
 
108
  background-color: #4a90e2 !important;
109
  }
110
  .gr-button:hover {
111
+
112
  background-color: #3a7bc8 !important;
113
  }
114
  """
 
126
 
127
  with gr.Row():
128
 
129
+ with gr.Column(scale=1):
130
  schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
131
  tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
132
 
133
  with gr.Column(scale=2):
134
+ query_input = gr.Textbox(lines=3, label="Text Query", placeholder="Enter your text query here...")
135
  with gr.Row():
136
  with gr.Column(scale=7):
137
  pass
138
  with gr.Column(scale=1):
139
  generate_query_button = gr.Button("Run Query", variant="primary")
140
 
141
+ with gr.Tabs():
142
+ with gr.Tab("Plot"):
143
+ result_plot = gr.Plot()
 
 
 
 
 
 
 
 
144
 
145
  schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
146
+ generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot])
147
 
148
  if __name__ == "__main__":
149
  demo.launch(debug=True)
plot_utils.py DELETED
@@ -1,85 +0,0 @@
1
- import matplotlib.pyplot as plt
2
- import numpy as np
3
-
4
- def plot_bar_chart(data):
5
- fig, ax = plt.subplots()
6
- labels = data['labels']
7
- values = data['values']
8
-
9
- colors = ['darkviolet', 'indigo', 'blueviolet']
10
- width = 0.2
11
- x = np.arange(len(labels))
12
-
13
- for i, value in enumerate(values):
14
- ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])
15
-
16
- ax.set_xticks(x + width / 2 * (len(values) - 1))
17
- ax.set_xticklabels(labels)
18
-
19
- ax.set_title('Bar Chart')
20
- ax.set_xlabel('Labels')
21
- ax.set_ylabel('Values')
22
- ax.legend()
23
- return fig
24
-
25
- def plot_horizontal_bar_chart(data):
26
- fig, ax = plt.subplots()
27
- labels = data['labels']
28
- values = data['values']
29
-
30
- for value in values:
31
- ax.barh(labels, value['data'], label=value['label'], color=['darkviolet', 'indigo', 'blueviolet'])
32
-
33
- ax.set_title('Horizontal Bar Chart')
34
- ax.set_xlabel('Values')
35
- ax.set_ylabel('Labels')
36
- ax.legend()
37
- return fig
38
-
39
- def plot_line_graph(data):
40
- fig, ax = plt.subplots()
41
- x_values = data['xValues']
42
- y_values_list = data['yValues']
43
-
44
- for y_values in y_values_list:
45
- label = y_values.get('label', None)
46
- ax.plot(x_values, y_values['data'], label=label)
47
-
48
- ax.set_title('Line Graph')
49
- ax.set_xlabel('X Values')
50
- ax.set_ylabel('Y Values')
51
-
52
- if any('label' in y_values for y_values in y_values_list):
53
- ax.legend()
54
-
55
- return fig
56
-
57
- def plot_scatter(data):
58
- fig, ax = plt.subplots()
59
- for series in data['series']:
60
- x_values = [point['x'] for point in series['data']]
61
- y_values = [point['y'] for point in series['data']]
62
- label = series.get('label', None)
63
- ax.scatter(x_values, y_values, label=label)
64
-
65
- ax.set_title('Scatter Plot')
66
- ax.set_xlabel('X Values')
67
- ax.set_ylabel('Y Values')
68
-
69
- if any('label' in series for series in data['series']):
70
- ax.legend()
71
-
72
- return fig
73
-
74
- def plot_pie(data):
75
- values = [item['value'] for item in data]
76
- labels = [item['label'] for item in data]
77
-
78
- fig, ax = plt.subplots()
79
- wedges, texts, autotexts = ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=140)
80
- ax.set_title('Pie Chart')
81
-
82
- for text in texts + autotexts:
83
- text.set_fontsize(10)
84
-
85
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,3 @@
1
  huggingface_hub
2
- accelerate
3
- bitsandbytes
4
  transformers
5
- duckdb
6
- langchain-huggingface
7
- langchain-core
8
- sentencepiece
 
1
  huggingface_hub
 
 
2
  transformers
3
+ duckdb
 
 
 
visualization_prompt.py DELETED
@@ -1,154 +0,0 @@
1
- barGraphIntstruction = '''
2
-
3
- Where data is: {
4
- labels: string[]
5
- values: {\data: number[], label: string}[]
6
- }
7
-
8
- The output must follow this format strictly, even if the input data differs from the examples below.
9
- // Examples of usage:
10
- Each label represents a column on the x axis.
11
- Each array in values represents a different entity.
12
-
13
- Here we are looking at average income for each month.
14
- {
15
- labels: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
16
- values: [{data:[21.5, 25.0, 47.5, 64.8, 105.5, 133.2], label: 'Income'}],
17
- }
18
-
19
- Here we are looking at the performance of american and european players for each series. Since there are two entities, we have two arrays in values.
20
- {
21
- labels: ['series A', 'series B', 'series C'],
22
- values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
23
- }
24
-
25
- The output format must be consistent with this structure, regardless of the specific input data.
26
- '''
27
-
28
- horizontalBarGraphIntstruction = '''
29
-
30
- Where data is: {
31
- labels: string[]
32
- values: {\data: number[], label: string}[]
33
- }
34
-
35
- // Examples of usage:
36
- Each label represents a column on the x axis.
37
- Each array in values represents a different entity.
38
-
39
- Here we are looking at average income for each month.
40
- {
41
- labels: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
42
- values: [{data:[21.5, 25.0, 47.5, 64.8, 105.5, 133.2], label: 'Income'}],
43
- }
44
-
45
- Here we are looking at the performance of american and european players for each series. Since there are two entities, we have two arrays in values.
46
- {
47
- labels: ['series A', 'series B', 'series C'],
48
- values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
49
- }
50
-
51
- '''
52
-
53
-
54
- lineGraphIntstruction = '''
55
-
56
- Where data is: {
57
- xValues: number[] | string[]
58
- yValues: { data: number[]; label: string }[]
59
- }
60
-
61
- // Examples of usage:
62
-
63
- Here we are looking at the momentum of a body as a function of mass.
64
- {
65
- xValues: ['2020', '2021', '2022', '2023', '2024'],
66
- yValues: [
67
- { data: [2, 5.5, 2, 8.5, 1.5]},
68
- ],
69
- }
70
-
71
- Here we are looking at the performance of american and european players for each year. Since there are two entities, we have two arrays in yValues.
72
- {
73
- xValues: ['2020', '2021', '2022', '2023', '2024'],
74
- yValues: [
75
- { data: [2, 5.5, 2, 8.5, 1.5], label: 'American' },
76
- { data: [2, 5.5, 2, 8.5, 1.5], label: 'European' },
77
- ],
78
- }
79
- '''
80
-
81
- pieChartIntstruction = '''
82
-
83
- Where data is: {
84
- labels: string
85
- values: number
86
- }[]
87
-
88
- // Example usage:
89
- [
90
- { id: 0, value: 10, label: 'series A' },
91
- { id: 1, value: 15, label: 'series B' },
92
- { id: 2, value: 20, label: 'series C' },
93
- ],
94
- '''
95
-
96
- scatterPlotIntstruction = '''
97
- Where data is: {
98
- series: {
99
- data: { x: number; y: number; id: number }[]
100
- label: string
101
- }[]
102
- }
103
-
104
- // Examples of usage:
105
- 1. Here each data array represents the points for a different entity.
106
- We are looking for correlation between amount spent and quantity bought for men and women.
107
- {
108
- series: [
109
- {
110
- data: [
111
- { x: 100, y: 200, id: 1 },
112
- { x: 120, y: 100, id: 2 },
113
- { x: 170, y: 300, id: 3 },
114
- ],
115
- label: 'Men',
116
- },
117
- {
118
- data: [
119
- { x: 300, y: 300, id: 1 },
120
- { x: 400, y: 500, id: 2 },
121
- { x: 200, y: 700, id: 3 },
122
- ],
123
- label: 'Women',
124
- }
125
- ],
126
- }
127
-
128
- 2. Here we are looking for correlation between the height and weight of players.
129
- {
130
- series: [
131
- {
132
- data: [
133
- { x: 180, y: 80, id: 1 },
134
- { x: 170, y: 70, id: 2 },
135
- { x: 160, y: 60, id: 3 },
136
- ],
137
- label: 'Players',
138
- },
139
- ],
140
- }
141
-
142
- // Note: Each object in the 'data' array represents a point on the scatter plot.
143
- // The 'x' and 'y' values determine the position of the point, and 'id' is a unique identifier.
144
- // Multiple series can be represented, each as an object in the outer array.
145
- '''
146
-
147
-
148
- graph_instructions = {
149
- "bar": barGraphIntstruction,
150
- "horizontal_bar": horizontalBarGraphIntstruction,
151
- "line": lineGraphIntstruction,
152
- "pie": pieChartIntstruction,
153
- "scatter": scatterPlotIntstruction
154
- }