Spaces:
Running
Running
Mustehson
commited on
Commit
·
9a42f0f
1
Parent(s):
fb83515
Agent Execution
Browse files- app.py +49 -260
- plot_utils.py +0 -85
- requirements.txt +1 -6
- visualization_prompt.py +0 -154
app.py
CHANGED
@@ -1,16 +1,8 @@
|
|
1 |
import os
|
2 |
-
import json
|
3 |
-
import torch
|
4 |
import duckdb
|
5 |
-
import spaces
|
6 |
-
import pandas as pd
|
7 |
import gradio as gr
|
8 |
-
import
|
9 |
-
from
|
10 |
-
from langchain_core.prompts import ChatPromptTemplate
|
11 |
-
from langchain_huggingface.llms import HuggingFacePipeline
|
12 |
-
from plot_utils import plot_bar_chart, plot_horizontal_bar_chart, plot_line_graph, plot_scatter, plot_pie
|
13 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
|
14 |
|
15 |
# Height of the Tabs Text Area
|
16 |
TAB_LINES = 8
|
@@ -22,32 +14,7 @@ print('Connecting to DB...')
|
|
22 |
# Connect to DB
|
23 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
24 |
|
25 |
-
|
26 |
-
device = torch.device("cuda")
|
27 |
-
print(f"Using GPU: {torch.cuda.get_device_name(device)}")
|
28 |
-
else:
|
29 |
-
device = torch.device("cpu")
|
30 |
-
print("Using CPU")
|
31 |
-
|
32 |
-
print('Loading Model...')
|
33 |
-
|
34 |
-
|
35 |
-
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
|
36 |
-
|
37 |
-
quantization_config = BitsAndBytesConfig(
|
38 |
-
load_in_4bit=True,
|
39 |
-
bnb_4bit_compute_dtype=torch.bfloat16,
|
40 |
-
bnb_4bit_use_double_quant=True,
|
41 |
-
bnb_4bit_quant_type= "nf4")
|
42 |
-
|
43 |
-
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", quantization_config=quantization_config,
|
44 |
-
device_map="auto", torch_dtype=torch.bfloat16)
|
45 |
-
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
|
46 |
-
llm = HuggingFacePipeline(pipeline=pipe)
|
47 |
-
|
48 |
-
|
49 |
-
print('Model Loaded...')
|
50 |
-
print(f'Model Device: {model.device}')
|
51 |
|
52 |
def get_schemas():
|
53 |
schemas = conn.execute("""
|
@@ -79,225 +46,54 @@ def get_table_schema(table):
|
|
79 |
else:
|
80 |
old_path = table
|
81 |
ddl_create = ddl_create.replace(old_path, full_path)
|
82 |
-
return ddl_create
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
{
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
result = llm.invoke(prompt)
|
104 |
-
return result
|
105 |
-
|
106 |
-
|
107 |
-
def get_visualization_type(text_query, sql_query, sql_result):
|
108 |
-
system_message = '''
|
109 |
-
You are an AI assistant that recommends appropriate data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable type of graph or chart to visualize the data. If no visualization is appropriate, indicate that.
|
110 |
-
|
111 |
-
Available chart types and their use cases:
|
112 |
-
- Bar Graphs: Best for comparing categorical data or showing changes over time when categories are discrete and the number of categories is more than 2. Use for questions like "What are the sales figures for each product?" or "How does the population of cities compare? or "What percentage of each city is male?"
|
113 |
-
- Horizontal Bar Graphs: Best for comparing categorical data or showing changes over time when the number of categories is small or the disparity between categories is large. Use for questions like "Show the revenue of A and B?" or "How does the population of 2 cities compare?" or "How many men and women got promoted?" or "What percentage of men and what percentage of women got promoted?" when the disparity between categories is large.
|
114 |
-
- Scatter Plots: Useful for identifying relationships or correlations between two numerical variables or plotting distributions of data. Best used when both x axis and y axis are continuous. Use for questions like "Plot a distribution of the fares (where the x axis is the fare and the y axis is the count of people who paid that fare)" or "Is there a relationship between advertising spend and sales?" or "How do height and weight correlate in the dataset? Do not use it for questions that do not have a continuous x axis."
|
115 |
-
- Pie Charts: Ideal for showing proportions or percentages within a whole. Use for questions like "What is the market share distribution among different companies?" or "What percentage of the total revenue comes from each product?"
|
116 |
-
- Line Graphs: Best for showing trends and distributionsover time. Best used when both x axis and y axis are continuous. Used for questions like "How have website visits changed over the year?" or "What is the trend in temperature over the past decade?". Do not use it for questions that do not have a continuous x axis or a time based x axis.
|
117 |
-
|
118 |
-
Consider these types of questions when recommending a visualization:
|
119 |
-
1. Aggregations and Summarizations (e.g., "What is the average revenue by month?" - Line Graph)
|
120 |
-
2. Comparisons (e.g., "Compare the sales figures of Product A and Product B over the last year." - Line or Column Graph)
|
121 |
-
3. Plotting Distributions (e.g., "Plot a distribution of the age of users" - Scatter Plot)
|
122 |
-
4. Trends Over Time (e.g., "What is the trend in the number of active users over the past year?" - Line Graph)
|
123 |
-
5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
|
124 |
-
6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)
|
125 |
-
|
126 |
-
Generate a JSON object. The JSON object should have the following structure:
|
127 |
-
|
128 |
-
Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
|
129 |
-
Reason: [Brief explanation for your recommendation]
|
130 |
-
|
131 |
-
'''
|
132 |
-
human_message = '''
|
133 |
-
User question: {question}
|
134 |
-
SQL query: {sql_query}
|
135 |
-
Query results: {results}
|
136 |
-
|
137 |
-
Recommend a visualization:
|
138 |
-
'''
|
139 |
-
|
140 |
-
prompt = ChatPromptTemplate.from_messages([
|
141 |
-
("system", system_message),
|
142 |
-
("human", human_message),
|
143 |
-
])
|
144 |
-
|
145 |
-
final_prompt = prompt.format_prompt(question=text_query,
|
146 |
-
sql_query=sql_query, results=sql_result)
|
147 |
-
response = run_llm(final_prompt)
|
148 |
-
response = response.replace('```', '')
|
149 |
-
response = response.replace('json', '')
|
150 |
-
json_data = json.loads(response)
|
151 |
-
visualization = json_data['Recommended Visualization']
|
152 |
-
reason = json_data['Reason']
|
153 |
-
print(visualization, reason)
|
154 |
-
return visualization, reason
|
155 |
-
|
156 |
-
def format_bar_data(results, question, sql_query):
|
157 |
-
if isinstance(results, str):
|
158 |
-
results = eval(results)
|
159 |
-
|
160 |
-
if len(results[0]) == 2:
|
161 |
-
labels = [str(row[0]) for row in results]
|
162 |
-
data = [float(row[1]) if row[1] is not None else 0 for row in results]
|
163 |
-
|
164 |
-
prompt = ChatPromptTemplate.from_messages([
|
165 |
-
("system", "You are a data labeling expert. Given a question, SQL Query used, and some data, provide a concise and relevant label for the data series."),
|
166 |
-
("human", "Question: {question}\n SQL Query: {sql_query} \nData (first few rows): {data}\n\n Provide a concise label name for this y axis. Just give me the name."),
|
167 |
-
])
|
168 |
-
prompt = prompt.format_prompt(question=question, data=str(results[:2]),
|
169 |
-
sql_query=sql_query)
|
170 |
-
label = run_llm(prompt)
|
171 |
-
label = label.replace('**Answer:**', '')
|
172 |
-
values = [{"data": data, "label": label}]
|
173 |
-
elif len(results[0]) == 3:
|
174 |
-
categories = set(row[1] for row in results)
|
175 |
-
labels = list(categories)
|
176 |
-
entities = set(row[0] for row in results)
|
177 |
-
values = []
|
178 |
-
for entity in entities:
|
179 |
-
entity_data = [float(row[2]) for row in results if row[0] == entity]
|
180 |
-
values.append({"data": entity_data, "label": str(entity)})
|
181 |
-
else:
|
182 |
-
raise ValueError("Unexpected data format in results")
|
183 |
-
|
184 |
-
formatted_data = {
|
185 |
-
"labels": labels,
|
186 |
-
"values": values
|
187 |
}
|
|
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user,Given the question, SQL query, and results, your job is to Understand the question and SQL query. Ensure the result matches the query and can be easily visualized."),
|
196 |
-
("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
|
197 |
-
])
|
198 |
-
|
199 |
-
|
200 |
-
prompt = template.format_prompt(question=text_query, sql_query=sql_query,
|
201 |
-
results=sql_result, instructions=instruction)
|
202 |
-
formatted_data = run_llm(prompt)
|
203 |
-
formatted_data = formatted_data.replace('```', '')
|
204 |
-
formatted_data = formatted_data.replace('json', '')
|
205 |
-
print(f'Formatted Data {formatted_data}')
|
206 |
-
return json.loads(formatted_data.replace('.', '').strip())
|
207 |
-
|
208 |
-
def visualize_result(text_query, visualization_type, sql_query,
|
209 |
-
sql_result):
|
210 |
-
|
211 |
-
if visualization_type == 'bar':
|
212 |
-
try:
|
213 |
-
data = format_bar_data(sql_result, text_query, sql_query)
|
214 |
-
except Exception as e:
|
215 |
-
data = format_data(text_query=text_query, sql_query=sql_query,
|
216 |
-
sql_result=sql_result, visualization_type=visualization_type)
|
217 |
-
return plot_bar_chart(data)
|
218 |
-
|
219 |
-
elif visualization_type == 'horizontal_bar':
|
220 |
-
data = format_data(text_query=text_query, sql_query=sql_query,
|
221 |
-
sql_result=sql_result, visualization_type=visualization_type)
|
222 |
-
return plot_horizontal_bar_chart(data)
|
223 |
-
|
224 |
-
elif visualization_type == 'line':
|
225 |
-
data = format_data(text_query=text_query, sql_query=sql_query,
|
226 |
-
sql_result=sql_result, visualization_type=visualization_type)
|
227 |
-
return plot_line_graph(data)
|
228 |
-
|
229 |
-
elif visualization_type == 'pie':
|
230 |
-
data = format_data(text_query=text_query, sql_query=sql_query,
|
231 |
-
sql_result=sql_result, visualization_type=visualization_type)
|
232 |
-
return plot_pie(data)
|
233 |
-
|
234 |
-
elif visualization_type == 'scatter':
|
235 |
-
data = format_data(text_query=text_query, sql_query=sql_query,
|
236 |
-
sql_result=sql_result, visualization_type=visualization_type)
|
237 |
-
return plot_scatter(data)
|
238 |
-
|
239 |
-
elif visualization_type == 'none':
|
240 |
-
fig, ax = plt.subplots()
|
241 |
-
ax.set_visible(False)
|
242 |
-
return fig
|
243 |
-
|
244 |
-
|
245 |
|
246 |
def main(table, text_query):
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
ax.set_visible(False)
|
251 |
-
|
252 |
-
schema = get_table_schema(table)
|
253 |
-
prompt = get_prompt(schema, text_query)
|
254 |
-
try:
|
255 |
-
generated_sql_query = run_llm(prompt)
|
256 |
-
print(f'Generated SQL Query: {generated_sql_query}')
|
257 |
-
except Exception as e:
|
258 |
-
return generate_output(schema, prompt, '', fig, pd.DataFrame([{"error": f"❌ Unable to generate the SQL query. {e}"}]))
|
259 |
-
|
260 |
-
try:
|
261 |
-
sql_query_result_raw = conn.sql(generated_sql_query)
|
262 |
-
sql_query_result = sql_query_result_raw.fetchall()
|
263 |
-
sql_query_df = sql_query_result_raw.df()
|
264 |
-
|
265 |
-
print(f"SQL Query Result: {sql_query_result}")
|
266 |
-
except Exception as e:
|
267 |
-
return generate_output(schema, prompt, generated_sql_query, fig, pd.DataFrame([{"error": f"❌ Unable to execute the SQL query. {e}"}]))
|
268 |
-
|
269 |
-
if len(sql_query_result) >= 100:
|
270 |
-
gr.Warning(f"⚠️ Data is too large for visualization. Please refine your query.")
|
271 |
-
return generate_output(schema, prompt, generated_sql_query,
|
272 |
-
fig, sql_query_df)
|
273 |
|
274 |
try:
|
275 |
-
|
276 |
-
sql_query=generated_sql_query, sql_result=sql_query_result)
|
277 |
-
|
278 |
-
print(f"Visualization Type: {visualization_type}")
|
279 |
-
print(f"Reason: {reason}")
|
280 |
-
|
281 |
except Exception as e:
|
282 |
-
|
283 |
-
|
|
|
284 |
|
285 |
-
if visualization_type != 'none':
|
286 |
-
try:
|
287 |
-
plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
|
288 |
-
sql_result=sql_query_result, visualization_type=visualization_type)
|
289 |
-
|
290 |
-
except Exception as e:
|
291 |
-
gr.Warning(f"⚠️ {e}")
|
292 |
-
return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)
|
293 |
-
else:
|
294 |
-
gr.Warning(f"⚠️ {reason}")
|
295 |
-
return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)
|
296 |
-
|
297 |
-
return generate_output(schema, prompt, generated_sql_query, plot, sql_query_df)
|
298 |
-
|
299 |
-
def generate_output(schema, prompt, generated_sql_query, result_plot, sql_query_df):
|
300 |
-
return [schema, prompt, generated_sql_query, result_plot, sql_query_df]
|
301 |
|
302 |
custom_css = """
|
303 |
.gradio-container {
|
@@ -312,6 +108,7 @@ custom_css = """
|
|
312 |
background-color: #4a90e2 !important;
|
313 |
}
|
314 |
.gr-button:hover {
|
|
|
315 |
background-color: #3a7bc8 !important;
|
316 |
}
|
317 |
"""
|
@@ -329,32 +126,24 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
|
|
329 |
|
330 |
with gr.Row():
|
331 |
|
332 |
-
with gr.Column(scale=1
|
333 |
schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
|
334 |
tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
|
335 |
|
336 |
with gr.Column(scale=2):
|
337 |
-
query_input = gr.Textbox(lines=
|
338 |
with gr.Row():
|
339 |
with gr.Column(scale=7):
|
340 |
pass
|
341 |
with gr.Column(scale=1):
|
342 |
generate_query_button = gr.Button("Run Query", variant="primary")
|
343 |
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
with gr.Tab("Plot"):
|
348 |
-
result_plot = gr.Plot()
|
349 |
-
with gr.Tab("SQL Query"):
|
350 |
-
generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
|
351 |
-
with gr.Tab("Prompt"):
|
352 |
-
input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
|
353 |
-
with gr.Tab("Schema"):
|
354 |
-
table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
|
355 |
|
356 |
schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
|
357 |
-
generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[
|
358 |
|
359 |
if __name__ == "__main__":
|
360 |
demo.launch(debug=True)
|
|
|
1 |
import os
|
|
|
|
|
2 |
import duckdb
|
|
|
|
|
3 |
import gradio as gr
|
4 |
+
from transformers import HfEngine, ReactCodeAgent
|
5 |
+
from transformers.agents import Tool
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Height of the Tabs Text Area
|
8 |
TAB_LINES = 8
|
|
|
14 |
# Connect to DB
|
15 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
16 |
|
17 |
+
llm_engine = HfEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def get_schemas():
|
20 |
schemas = conn.execute("""
|
|
|
46 |
else:
|
47 |
old_path = table
|
48 |
ddl_create = ddl_create.replace(old_path, full_path)
|
49 |
+
return ddl_create, full_path
|
50 |
+
|
51 |
+
def get_visualization(question, tool):
|
52 |
+
agent = ReactCodeAgent(tools=[tool], llm_engine=llm_engine, add_base_tools=True,
|
53 |
+
additional_authorized_imports=['matplotlib.pyplot',
|
54 |
+
'pandas', 'plotly.express',
|
55 |
+
'seaborn'], max_iterations=20)
|
56 |
+
fig = agent.run(
|
57 |
+
task=f'''
|
58 |
+
Use seaborn. Always
|
59 |
+
Question: {question}
|
60 |
+
Always use the right colors.
|
61 |
+
If the question is about showing n number of rows return empty figure.
|
62 |
+
In the end you have to return a final fig using the `final_answer` tool
|
63 |
+
''',
|
64 |
+
)
|
65 |
|
66 |
+
return fig
|
67 |
+
|
68 |
+
class SQLExecutorTool(Tool):
|
69 |
+
name = "sql_engine"
|
70 |
+
inputs = {
|
71 |
+
"query": {
|
72 |
+
"type": "text",
|
73 |
+
"description": f"The query to perform. This should be correct DuckDB SQL.",
|
74 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
}
|
76 |
+
output_type = "pandas.core.frame.DataFrame"
|
77 |
|
78 |
+
def forward(self, query: str) -> str:
|
79 |
+
with duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True) as con:
|
80 |
+
output_df = conn.sql(query).df()
|
81 |
+
return output_df
|
82 |
+
|
83 |
+
tool = SQLExecutorTool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def main(table, text_query):
|
86 |
+
schema, _ = get_table_schema(table)
|
87 |
+
tool.description = f"""Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result.
|
88 |
+
The table schema is as follows: \n{schema}"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
try:
|
91 |
+
fig = get_visualization(question=text_query, tool=tool)
|
|
|
|
|
|
|
|
|
|
|
92 |
except Exception as e:
|
93 |
+
gr.Warning(f"❌ Unable to generate the visualization. {e}")
|
94 |
+
|
95 |
+
return fig
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
custom_css = """
|
99 |
.gradio-container {
|
|
|
108 |
background-color: #4a90e2 !important;
|
109 |
}
|
110 |
.gr-button:hover {
|
111 |
+
|
112 |
background-color: #3a7bc8 !important;
|
113 |
}
|
114 |
"""
|
|
|
126 |
|
127 |
with gr.Row():
|
128 |
|
129 |
+
with gr.Column(scale=1):
|
130 |
schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
|
131 |
tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
|
132 |
|
133 |
with gr.Column(scale=2):
|
134 |
+
query_input = gr.Textbox(lines=3, label="Text Query", placeholder="Enter your text query here...")
|
135 |
with gr.Row():
|
136 |
with gr.Column(scale=7):
|
137 |
pass
|
138 |
with gr.Column(scale=1):
|
139 |
generate_query_button = gr.Button("Run Query", variant="primary")
|
140 |
|
141 |
+
with gr.Tabs():
|
142 |
+
with gr.Tab("Plot"):
|
143 |
+
result_plot = gr.Plot()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
|
146 |
+
generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot])
|
147 |
|
148 |
if __name__ == "__main__":
|
149 |
demo.launch(debug=True)
|
plot_utils.py
DELETED
@@ -1,85 +0,0 @@
|
|
1 |
-
import matplotlib.pyplot as plt
|
2 |
-
import numpy as np
|
3 |
-
|
4 |
-
def plot_bar_chart(data):
|
5 |
-
fig, ax = plt.subplots()
|
6 |
-
labels = data['labels']
|
7 |
-
values = data['values']
|
8 |
-
|
9 |
-
colors = ['darkviolet', 'indigo', 'blueviolet']
|
10 |
-
width = 0.2
|
11 |
-
x = np.arange(len(labels))
|
12 |
-
|
13 |
-
for i, value in enumerate(values):
|
14 |
-
ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])
|
15 |
-
|
16 |
-
ax.set_xticks(x + width / 2 * (len(values) - 1))
|
17 |
-
ax.set_xticklabels(labels)
|
18 |
-
|
19 |
-
ax.set_title('Bar Chart')
|
20 |
-
ax.set_xlabel('Labels')
|
21 |
-
ax.set_ylabel('Values')
|
22 |
-
ax.legend()
|
23 |
-
return fig
|
24 |
-
|
25 |
-
def plot_horizontal_bar_chart(data):
|
26 |
-
fig, ax = plt.subplots()
|
27 |
-
labels = data['labels']
|
28 |
-
values = data['values']
|
29 |
-
|
30 |
-
for value in values:
|
31 |
-
ax.barh(labels, value['data'], label=value['label'], color=['darkviolet', 'indigo', 'blueviolet'])
|
32 |
-
|
33 |
-
ax.set_title('Horizontal Bar Chart')
|
34 |
-
ax.set_xlabel('Values')
|
35 |
-
ax.set_ylabel('Labels')
|
36 |
-
ax.legend()
|
37 |
-
return fig
|
38 |
-
|
39 |
-
def plot_line_graph(data):
|
40 |
-
fig, ax = plt.subplots()
|
41 |
-
x_values = data['xValues']
|
42 |
-
y_values_list = data['yValues']
|
43 |
-
|
44 |
-
for y_values in y_values_list:
|
45 |
-
label = y_values.get('label', None)
|
46 |
-
ax.plot(x_values, y_values['data'], label=label)
|
47 |
-
|
48 |
-
ax.set_title('Line Graph')
|
49 |
-
ax.set_xlabel('X Values')
|
50 |
-
ax.set_ylabel('Y Values')
|
51 |
-
|
52 |
-
if any('label' in y_values for y_values in y_values_list):
|
53 |
-
ax.legend()
|
54 |
-
|
55 |
-
return fig
|
56 |
-
|
57 |
-
def plot_scatter(data):
|
58 |
-
fig, ax = plt.subplots()
|
59 |
-
for series in data['series']:
|
60 |
-
x_values = [point['x'] for point in series['data']]
|
61 |
-
y_values = [point['y'] for point in series['data']]
|
62 |
-
label = series.get('label', None)
|
63 |
-
ax.scatter(x_values, y_values, label=label)
|
64 |
-
|
65 |
-
ax.set_title('Scatter Plot')
|
66 |
-
ax.set_xlabel('X Values')
|
67 |
-
ax.set_ylabel('Y Values')
|
68 |
-
|
69 |
-
if any('label' in series for series in data['series']):
|
70 |
-
ax.legend()
|
71 |
-
|
72 |
-
return fig
|
73 |
-
|
74 |
-
def plot_pie(data):
|
75 |
-
values = [item['value'] for item in data]
|
76 |
-
labels = [item['label'] for item in data]
|
77 |
-
|
78 |
-
fig, ax = plt.subplots()
|
79 |
-
wedges, texts, autotexts = ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=140)
|
80 |
-
ax.set_title('Pie Chart')
|
81 |
-
|
82 |
-
for text in texts + autotexts:
|
83 |
-
text.set_fontsize(10)
|
84 |
-
|
85 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,8 +1,3 @@
|
|
1 |
huggingface_hub
|
2 |
-
accelerate
|
3 |
-
bitsandbytes
|
4 |
transformers
|
5 |
-
duckdb
|
6 |
-
langchain-huggingface
|
7 |
-
langchain-core
|
8 |
-
sentencepiece
|
|
|
1 |
huggingface_hub
|
|
|
|
|
2 |
transformers
|
3 |
+
duckdb
|
|
|
|
|
|
visualization_prompt.py
DELETED
@@ -1,154 +0,0 @@
|
|
1 |
-
barGraphIntstruction = '''
|
2 |
-
|
3 |
-
Where data is: {
|
4 |
-
labels: string[]
|
5 |
-
values: {\data: number[], label: string}[]
|
6 |
-
}
|
7 |
-
|
8 |
-
The output must follow this format strictly, even if the input data differs from the examples below.
|
9 |
-
// Examples of usage:
|
10 |
-
Each label represents a column on the x axis.
|
11 |
-
Each array in values represents a different entity.
|
12 |
-
|
13 |
-
Here we are looking at average income for each month.
|
14 |
-
{
|
15 |
-
labels: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
|
16 |
-
values: [{data:[21.5, 25.0, 47.5, 64.8, 105.5, 133.2], label: 'Income'}],
|
17 |
-
}
|
18 |
-
|
19 |
-
Here we are looking at the performance of american and european players for each series. Since there are two entities, we have two arrays in values.
|
20 |
-
{
|
21 |
-
labels: ['series A', 'series B', 'series C'],
|
22 |
-
values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
|
23 |
-
}
|
24 |
-
|
25 |
-
The output format must be consistent with this structure, regardless of the specific input data.
|
26 |
-
'''
|
27 |
-
|
28 |
-
horizontalBarGraphIntstruction = '''
|
29 |
-
|
30 |
-
Where data is: {
|
31 |
-
labels: string[]
|
32 |
-
values: {\data: number[], label: string}[]
|
33 |
-
}
|
34 |
-
|
35 |
-
// Examples of usage:
|
36 |
-
Each label represents a column on the x axis.
|
37 |
-
Each array in values represents a different entity.
|
38 |
-
|
39 |
-
Here we are looking at average income for each month.
|
40 |
-
{
|
41 |
-
labels: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
|
42 |
-
values: [{data:[21.5, 25.0, 47.5, 64.8, 105.5, 133.2], label: 'Income'}],
|
43 |
-
}
|
44 |
-
|
45 |
-
Here we are looking at the performance of american and european players for each series. Since there are two entities, we have two arrays in values.
|
46 |
-
{
|
47 |
-
labels: ['series A', 'series B', 'series C'],
|
48 |
-
values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
|
49 |
-
}
|
50 |
-
|
51 |
-
'''
|
52 |
-
|
53 |
-
|
54 |
-
lineGraphIntstruction = '''
|
55 |
-
|
56 |
-
Where data is: {
|
57 |
-
xValues: number[] | string[]
|
58 |
-
yValues: { data: number[]; label: string }[]
|
59 |
-
}
|
60 |
-
|
61 |
-
// Examples of usage:
|
62 |
-
|
63 |
-
Here we are looking at the momentum of a body as a function of mass.
|
64 |
-
{
|
65 |
-
xValues: ['2020', '2021', '2022', '2023', '2024'],
|
66 |
-
yValues: [
|
67 |
-
{ data: [2, 5.5, 2, 8.5, 1.5]},
|
68 |
-
],
|
69 |
-
}
|
70 |
-
|
71 |
-
Here we are looking at the performance of american and european players for each year. Since there are two entities, we have two arrays in yValues.
|
72 |
-
{
|
73 |
-
xValues: ['2020', '2021', '2022', '2023', '2024'],
|
74 |
-
yValues: [
|
75 |
-
{ data: [2, 5.5, 2, 8.5, 1.5], label: 'American' },
|
76 |
-
{ data: [2, 5.5, 2, 8.5, 1.5], label: 'European' },
|
77 |
-
],
|
78 |
-
}
|
79 |
-
'''
|
80 |
-
|
81 |
-
pieChartIntstruction = '''
|
82 |
-
|
83 |
-
Where data is: {
|
84 |
-
labels: string
|
85 |
-
values: number
|
86 |
-
}[]
|
87 |
-
|
88 |
-
// Example usage:
|
89 |
-
[
|
90 |
-
{ id: 0, value: 10, label: 'series A' },
|
91 |
-
{ id: 1, value: 15, label: 'series B' },
|
92 |
-
{ id: 2, value: 20, label: 'series C' },
|
93 |
-
],
|
94 |
-
'''
|
95 |
-
|
96 |
-
scatterPlotIntstruction = '''
|
97 |
-
Where data is: {
|
98 |
-
series: {
|
99 |
-
data: { x: number; y: number; id: number }[]
|
100 |
-
label: string
|
101 |
-
}[]
|
102 |
-
}
|
103 |
-
|
104 |
-
// Examples of usage:
|
105 |
-
1. Here each data array represents the points for a different entity.
|
106 |
-
We are looking for correlation between amount spent and quantity bought for men and women.
|
107 |
-
{
|
108 |
-
series: [
|
109 |
-
{
|
110 |
-
data: [
|
111 |
-
{ x: 100, y: 200, id: 1 },
|
112 |
-
{ x: 120, y: 100, id: 2 },
|
113 |
-
{ x: 170, y: 300, id: 3 },
|
114 |
-
],
|
115 |
-
label: 'Men',
|
116 |
-
},
|
117 |
-
{
|
118 |
-
data: [
|
119 |
-
{ x: 300, y: 300, id: 1 },
|
120 |
-
{ x: 400, y: 500, id: 2 },
|
121 |
-
{ x: 200, y: 700, id: 3 },
|
122 |
-
],
|
123 |
-
label: 'Women',
|
124 |
-
}
|
125 |
-
],
|
126 |
-
}
|
127 |
-
|
128 |
-
2. Here we are looking for correlation between the height and weight of players.
|
129 |
-
{
|
130 |
-
series: [
|
131 |
-
{
|
132 |
-
data: [
|
133 |
-
{ x: 180, y: 80, id: 1 },
|
134 |
-
{ x: 170, y: 70, id: 2 },
|
135 |
-
{ x: 160, y: 60, id: 3 },
|
136 |
-
],
|
137 |
-
label: 'Players',
|
138 |
-
},
|
139 |
-
],
|
140 |
-
}
|
141 |
-
|
142 |
-
// Note: Each object in the 'data' array represents a point on the scatter plot.
|
143 |
-
// The 'x' and 'y' values determine the position of the point, and 'id' is a unique identifier.
|
144 |
-
// Multiple series can be represented, each as an object in the outer array.
|
145 |
-
'''
|
146 |
-
|
147 |
-
|
148 |
-
graph_instructions = {
|
149 |
-
"bar": barGraphIntstruction,
|
150 |
-
"horizontal_bar": horizontalBarGraphIntstruction,
|
151 |
-
"line": lineGraphIntstruction,
|
152 |
-
"pie": pieChartIntstruction,
|
153 |
-
"scatter": scatterPlotIntstruction
|
154 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|