Condense query functions #42
functions/__init__.py CHANGED
@@ -1,9 +1,9 @@
- from .query_functions import SQLiteQuery, sqlite_query_func, sql_query_func, doc_db_query_func, graphql_query_func, graphql_schema_query, graphql_csv_query
+ from .query_functions import graphql_schema_query, graphql_csv_query, query_func
  from .chart_functions import table_generation_func, scatter_chart_generation_func, \
  line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, histogram_generation_func, scatter_chart_fig
  from .chat_functions import example_question_generator, chatbot_func
  from .stat_functions import regression_func

- __all__ = ["SQLiteQuery","sqlite_query_func","sql_query_func","doc_db_query_func","graphql_query_func","graphql_schema_query","graphql_csv_query","table_generation_func","scatter_chart_generation_func",
+ __all__ = ["query_func","graphql_schema_query","graphql_csv_query","table_generation_func","scatter_chart_generation_func",
  "line_chart_generation_func","bar_chart_generation_func","regression_func", "pie_chart_generation_func", "histogram_generation_func",
  "scatter_chart_fig","example_question_generator","chatbot_func"]
functions/chat_functions.py CHANGED
@@ -62,7 +62,8 @@ def example_question_generator(session_hash, data_source, name, titles, schema):
  return example_response["replies"][0].text

  def system_message(data_source, titles, schema=""):
-
+ print("TITLES")
+ print(titles)
  system_message_dict = {
  'file_upload' : f"""You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source' that contains the following columns: {titles}.
  You also have access to a function, called table_generation_func, that can take a query.csv file generated from our sql query and returns an iframe that we should display in our chat window.
@@ -111,13 +112,12 @@ def system_message(data_source, titles, schema=""):
  return system_message_dict[data_source]

  def chatbot_func(message, history, session_hash, data_source, titles, schema, *args):
- from functions import sqlite_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
- sql_query_func, doc_db_query_func, graphql_query_func, graphql_schema_query, graphql_csv_query, \
+ from functions import table_generation_func, regression_func, scatter_chart_generation_func, \
+ query_func, graphql_schema_query, graphql_csv_query, \
  line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
  import tools.tools as tools

- available_functions = {"sqlite_query_func": sqlite_query_func,"sql_query_func": sql_query_func,"doc_db_query_func": doc_db_query_func,
- "graphql_query_func": graphql_query_func,"graphql_schema_query": graphql_schema_query,"graphql_csv_query": graphql_csv_query,
+ available_functions = {"query_func":query_func,"graphql_schema_query": graphql_schema_query,"graphql_csv_query": graphql_csv_query,
  "table_generation_func":table_generation_func,
  "line_chart_generation_func":line_chart_generation_func,"bar_chart_generation_func":bar_chart_generation_func,
  "scatter_chart_generation_func":scatter_chart_generation_func, "pie_chart_generation_func":pie_chart_generation_func,
functions/query_functions.py CHANGED
@@ -35,28 +35,6 @@ class SQLiteQuery:
  self.connection.close()
  return {"results": results, "queries": queries, "csv_columns": column_names}

-
-
- def sqlite_query_func(queries: List[str], session_hash, **kwargs):
- dir_path = TEMP_DIR / str(session_hash)
- sql_query = SQLiteQuery(f'{dir_path}/file_upload/data_source.db')
- try:
- result = sql_query.run(queries, session_hash)
- if len(result["results"][0]) > 1000:
- print("QUERY TOO LARGE")
- return {"reply": f"""query result too large to be processed by llm, the query results are in our query.csv file.
- The column names of this query.csv file are: {result["csv_columns"]}.
- If you need to display the results directly, perhaps use the table_generation_func function."""}
- else:
- return {"reply": result["results"][0]}
-
- except Exception as e:
- reply = f"""There was an error running the SQL Query = {queries}
- The error is {e},
- You should probably try again.
- """
- return {"reply": reply}
-
  @component
  class PostgreSQLQuery:

@@ -82,30 +60,6 @@ class PostgreSQLQuery:
  results.append(f"{result}")
  self.connection.close()
  return {"results": results, "queries": queries, "csv_columns": column_names}
-
-
-
- def sql_query_func(queries: List[str], session_hash, args, **kwargs):
- sql_query = PostgreSQLQuery(args[0], args[1], args[2], args[3], args[4])
- try:
- result = sql_query.run(queries, session_hash)
- print("RESULT")
- print(result)
- if len(result["results"][0]) > 1000:
- print("QUERY TOO LARGE")
- return {"reply": f"""query result too large to be processed by llm, the query results are in our query.csv file.
- The column names of this query.csv file are: {result["csv_columns"]}.
- If you need to display the results directly, perhaps use the table_generation_func function."""}
- else:
- return {"reply": result["results"][0]}
-
- except Exception as e:
- reply = f"""There was an error running the SQL Query = {queries}
- The error is {e},
- You should probably try again.
- """
- print(reply)
- return {"reply": reply}

  @component
  class DocDBQuery:
@@ -155,29 +109,6 @@ class DocDBQuery:
  self.client.close()
  return {"results": results, "queries": aggregation_pipeline, "csv_columns": column_names}

-
-
- def doc_db_query_func(aggregation_pipeline: List[str], db_collection: AnyStr, session_hash, args, **kwargs):
- doc_db_query = DocDBQuery(args[0], args[1])
- try:
- result = doc_db_query.run(aggregation_pipeline, db_collection, session_hash)
- print("RESULT")
- if len(result["results"][0]) > 1000:
- print("QUERY TOO LARGE")
- return {"reply": f"""query result too large to be processed by llm, the query results are in our query.csv file.
- The column names of this query.csv file are: {result["csv_columns"]}.
- If you need to display the results directly, perhaps use the table_generation_func function."""}
- else:
- return {"reply": result["results"][0]}
-
- except Exception as e:
- reply = f"""There was an error running the NoSQL (Mongo) Query = {aggregation_pipeline}
- The error is {e},
- You should probably try again.
- """
- print(reply)
- return {"reply": reply}
-
  @component
  class GraphQLQuery:

@@ -214,12 +145,23 @@ class GraphQLQuery:
  results.append(f"{response_frame}")
  return {"results": results, "queries": graphql_query, "csv_columns": column_names}

-
-
- def graphql_query_func(graphql_query: AnyStr, session_hash, args, **kwargs):
- graphql_object = GraphQLQuery()
+ def query_func(queries:List[str], session_hash, session_folder, args, **kwargs):
  try:
- result = graphql_object.run(graphql_query, args[0], args[1], args[2], session_hash)
+ print("QUERY")
+ print(queries)
+ if session_folder == "file_upload":
+ dir_path = TEMP_DIR / str(session_hash)
+ sql_query = SQLiteQuery(f'{dir_path}/file_upload/data_source.db')
+ result = sql_query.run(queries, session_hash)
+ elif session_folder == "sql":
+ sql_query = PostgreSQLQuery(args[0], args[1], args[2], args[3], args[4])
+ result = sql_query.run(queries, session_hash)
+ elif session_folder == 'doc_db':
+ doc_db_query = DocDBQuery(args[0], args[1])
+ result = doc_db_query.run(queries, kwargs['db_collection'], session_hash)
+ elif session_folder == 'graphql':
+ graphql_object = GraphQLQuery()
+ result = graphql_object.run(queries, args[0], args[1], args[2], session_hash)
  print("RESULT")
  if len(result["results"][0]) > 1000:
  print("QUERY TOO LARGE")
@@ -230,7 +172,7 @@ def graphql_query_func(graphql_query: AnyStr, session_hash, args, **kwargs):
  return {"reply": result["results"][0]}

  except Exception as e:
- reply = f"""There was an error running the GraphQL Query = {graphql_query}
+ reply = f"""There was an error running the {session_folder} Query = {queries}
  The error is {e},
  You should probably try again.
  """
@@ -266,6 +208,7 @@ def graphql_csv_query(csv_query: AnyStr, session_hash, **kwargs):
  query = pd.read_csv(f'{dir_path}/graphql/query.csv')
  query.Name = 'query'
  print("GRAPHQL CSV QUERY")
+ print(csv_query)
  queried_df = sqldf(csv_query, locals())
  print(queried_df)
  column_names = list(queried_df.columns)
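The four per-backend wrappers (`sqlite_query_func`, `sql_query_func`, `doc_db_query_func`, `graphql_query_func`) collapse into a single `query_func` that branches on `session_folder`. A minimal usage sketch against the new signature, with placeholder session values:

```python
from functions import query_func

# session_folder picks the backend: "file_upload" (SQLite), "sql" (PostgreSQL),
# "doc_db" (MongoDB), or "graphql"; the values below are placeholders.
reply = query_func(
    ["SELECT AVG(balance) FROM data_source GROUP BY education"],  # queries
    session_hash="demo-session",     # locates TEMP_DIR/<session_hash>/file_upload/data_source.db
    session_folder="file_upload",
    args=[],                         # connection details; unused for the SQLite path
)
print(reply["reply"])
```

For the MongoDB path the collection name now travels through `**kwargs` as `db_collection`, while the PostgreSQL and GraphQL paths still take their connection details positionally from `args`.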
templates/data_file.py CHANGED
@@ -1,136 +1,136 @@
- import gradio as gr
- from functions import example_question_generator, chatbot_func
- from data_sources import process_data_upload
- from utils import message_dict
- import ast
-
- def run_example(input):
- return input
-
- def example_display(input):
- if input == None:
- display = True
- else:
- display = False
- return [gr.update(visible=display),gr.update(visible=display),gr.update(visible=display),gr.update(visible=display)]
-
- with gr.Blocks() as demo:
- description = gr.HTML("""
- <!-- Header -->
- <div class="max-w-4xl mx-auto mb-12 text-center">
- <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
- <h2 class="font-semibold text-blue-800 ">
- <i class="fas fa-info-circle mr-2"></i>Supported Files
- </h2>
- <div class="flex flex-wrap justify-center gap-3 pb-4 text-blue-700">
- <span class="tooltip">
- <i class="fas fa-file-csv mr-1"></i>CSV
- <span class="tooltip-text">Comma-separated values</span>
- </span>
- <span class="tooltip">
- <i class="fas fa-file-alt mr-1"></i>TSV
- <span class="tooltip-text">Tab-separated values</span>
- </span>
- <span class="tooltip">
- <i class="fas fa-file-alt mr-1"></i>TXT
- <span class="tooltip-text">Text files</span>
- </span>
- <span class="tooltip">
- <i class="fas fa-file-excel mr-1"></i>XLS/XLSX
- <span class="tooltip-text">Excel spreadsheets</span>
- </span>
- <span class="tooltip">
- <i class="fas fa-file-code mr-1"></i>XML
- <span class="tooltip-text">XML documents</span>
- </span>
- <span class="tooltip">
- <i class="fas fa-file-code mr-1"></i>JSON
- <span class="tooltip-text">JSON data files</span>
- </span>
- </div>
- </div>
- </div>
- """, elem_classes="description_component")
- example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
- example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
- example_file_3 = gr.File(visible=False, value="samples/tb_illness_data.csv")
- with gr.Row():
- example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
- example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
- example_btn_3 = gr.Button(value="Try Me: tb_illness_data.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
-
- file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker drop-zone border-2 border-dashed border-gray-300 rounded-lg hover:border-primary cursor-pointer bg-gray-50 hover:bg-blue-50 transition-colors duration-300", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
- example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
- example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
- example_btn_3.click(fn=run_example, inputs=example_file_3, outputs=file_output)
- file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2, example_btn_3, description])
-
- @gr.render(inputs=file_output)
- def data_options(filename, request: gr.Request):
- print(filename)
- if request.session_hash not in message_dict:
- message_dict[request.session_hash] = {}
- message_dict[request.session_hash]['file_upload'] = None
- if filename:
- process_message = process_upload(filename, request.session_hash)
- gr.HTML(value=process_message[1], padding=False)
- if process_message[0] == "success":
- if "bank_marketing_campaign" in filename:
- example_questions = [
- ["Describe the dataset"],
- ["What levels of education have the highest and lowest average balance?"],
- ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
- ["Can you generate a bar chart of education vs. average balance?"],
- ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"],
- ["Can we predict the relationship between the number of contacts performed before this campaign and the average balance?"],
- ["Can you plot the number of contacts performed before this campaign versus the duration and use balance as the size in a bubble chart?"]
- ]
- elif "online_retail_data" in filename:
- example_questions = [
- ["Describe the dataset"],
- ["What month had the highest revenue?"],
- ["Is revenue higher in the morning or afternoon?"],
- ["Can you generate a line graph of revenue per month?"],
- ["Can you generate a table of revenue per month?"],
- ["Can we predict how time of day affects transaction value in this data set?"],
- ["Can you plot revenue per month with size being the number of units sold that month in a bubble chart?"]
- ]
- else:
- try:
- generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'file_upload', '', process_message[1], ''))
- example_questions = [
- ["Describe the dataset"]
- ]
- for example in generated_examples:
- example_questions.append([example])
- except Exception as e:
- print("DATA FILE QUESTION GENERATION ERROR")
- print(e)
- example_questions = [
- ["Describe the dataset"],
- ["List the columns in the dataset"],
- ["What could this data be used for?"],
- ]
- session_hash = gr.Textbox(visible=False, value=request.session_hash)
- data_source = gr.Textbox(visible=False, value='file_upload')
- schema = gr.Textbox(visible=False, value='')
- titles = gr.Textbox(value=process_message[2], interactive=False, visible=False)
- bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
- chat = gr.ChatInterface(
- fn=chatbot_func,
- type='messages',
- chatbot=bot,
- title="Chat with your data file",
- concurrency_limit=None,
- examples=example_questions,
- additional_inputs=[session_hash, data_source, titles, schema]
- )
-
- def process_upload(upload_value, session_hash):
- if upload_value:
- process_message = process_data_upload(upload_value, session_hash)
- return process_message
-
-
- if __name__ == "__main__":
+ import gradio as gr
+ from functions import example_question_generator, chatbot_func
+ from data_sources import process_data_upload
+ from utils import message_dict
+ import ast
+
+ def run_example(input):
+ return input
+
+ def example_display(input):
+ if input == None:
+ display = True
+ else:
+ display = False
+ return [gr.update(visible=display),gr.update(visible=display),gr.update(visible=display),gr.update(visible=display)]
+
+ with gr.Blocks() as demo:
+ description = gr.HTML("""
+ <!-- Header -->
+ <div class="max-w-4xl mx-auto mb-12 text-center">
+ <div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
+ <h2 class="font-semibold text-blue-800 ">
+ <i class="fas fa-info-circle mr-2"></i>Supported Files
+ </h2>
+ <div class="flex flex-wrap justify-center gap-3 pb-4 text-blue-700">
+ <span class="tooltip">
+ <i class="fas fa-file-csv mr-1"></i>CSV
+ <span class="tooltip-text">Comma-separated values</span>
+ </span>
+ <span class="tooltip">
+ <i class="fas fa-file-alt mr-1"></i>TSV
+ <span class="tooltip-text">Tab-separated values</span>
+ </span>
+ <span class="tooltip">
+ <i class="fas fa-file-alt mr-1"></i>TXT
+ <span class="tooltip-text">Text files</span>
+ </span>
+ <span class="tooltip">
+ <i class="fas fa-file-excel mr-1"></i>XLS/XLSX
+ <span class="tooltip-text">Excel spreadsheets</span>
+ </span>
+ <span class="tooltip">
+ <i class="fas fa-file-code mr-1"></i>XML
+ <span class="tooltip-text">XML documents</span>
+ </span>
+ <span class="tooltip">
+ <i class="fas fa-file-code mr-1"></i>JSON
+ <span class="tooltip-text">JSON data files</span>
+ </span>
+ </div>
+ </div>
+ </div>
+ """, elem_classes="description_component")
+ example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
+ example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
+ example_file_3 = gr.File(visible=False, value="samples/tb_illness_data.csv")
+ with gr.Row():
+ example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
+ example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
+ example_btn_3 = gr.Button(value="Try Me: tb_illness_data.csv", elem_classes="sample-btn bg-gradient-to-r from-purple-500 to-indigo-600 text-white p-6 rounded-lg text-left hover:shadow-lg", size="md", variant="primary")
+
+ file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker drop-zone border-2 border-dashed border-gray-300 rounded-lg hover:border-primary cursor-pointer bg-gray-50 hover:bg-blue-50 transition-colors duration-300", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
+ example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
+ example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
+ example_btn_3.click(fn=run_example, inputs=example_file_3, outputs=file_output)
+ file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2, example_btn_3, description])
+
+ @gr.render(inputs=file_output)
+ def data_options(filename, request: gr.Request):
+ print(filename)
+ if request.session_hash not in message_dict:
+ message_dict[request.session_hash] = {}
+ message_dict[request.session_hash]['file_upload'] = None
+ if filename:
+ process_message = process_upload(filename, request.session_hash)
+ gr.HTML(value=process_message[1], padding=False)
+ if process_message[0] == "success":
+ if "bank_marketing_campaign" in filename:
+ example_questions = [
+ ["Describe the dataset"],
+ ["What levels of education have the highest and lowest average balance?"],
+ ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
+ ["Can you generate a bar chart of education vs. average balance?"],
+ ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"],
+ ["Can we predict the relationship between the number of contacts performed before this campaign and the average balance?"],
+ ["Can you plot the number of contacts performed before this campaign versus the duration and use balance as the size in a bubble chart?"]
+ ]
+ elif "online_retail_data" in filename:
+ example_questions = [
+ ["Describe the dataset"],
+ ["What month had the highest revenue?"],
+ ["Is revenue higher in the morning or afternoon?"],
+ ["Can you generate a line graph of revenue per month?"],
+ ["Can you generate a table of revenue per month?"],
+ ["Can we predict how time of day affects transaction value in this data set?"],
+ ["Can you plot revenue per month with size being the number of units sold that month in a bubble chart?"]
+ ]
+ else:
+ try:
+ generated_examples = ast.literal_eval(example_question_generator(request.session_hash, 'file_upload', '', process_message[1], ''))
+ example_questions = [
+ ["Describe the dataset"]
+ ]
+ for example in generated_examples:
+ example_questions.append([example])
+ except Exception as e:
+ print("DATA FILE QUESTION GENERATION ERROR")
+ print(e)
+ example_questions = [
+ ["Describe the dataset"],
+ ["List the columns in the dataset"],
+ ["What could this data be used for?"],
+ ]
+ session_hash = gr.Textbox(visible=False, value=request.session_hash)
+ data_source = gr.Textbox(visible=False, value='file_upload')
+ schema = gr.Textbox(visible=False, value='')
+ titles = gr.Textbox(value=process_message[2], interactive=False, visible=False)
+ bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
+ chat = gr.ChatInterface(
+ fn=chatbot_func,
+ type='messages',
+ chatbot=bot,
+ title="Chat with your data file",
+ concurrency_limit=None,
+ examples=example_questions,
+ additional_inputs=[session_hash, data_source, titles, schema]
+ )
+
+ def process_upload(upload_value, session_hash):
+ if upload_value:
+ process_message = process_data_upload(upload_value, session_hash)
+ return process_message
+
+
+ if __name__ == "__main__":
  demo.launch()
tools/tools.py CHANGED
@@ -10,7 +10,7 @@ def tools_call(session_hash, data_source, titles):
  {
  "type": "function",
  "function": {
- "name": "sqlite_query_func",
+ "name": "query_func",
  "description": f"""This is a tool useful to query a SQLite table called 'data_source' with the following Columns: {titles_string}.
  There may also be more columns in the table if the number of columns is too large to process.
  This function also saves the results of the query to csv file called query.csv.""",
@@ -34,7 +34,7 @@ def tools_call(session_hash, data_source, titles):
  {
  "type": "function",
  "function": {
- "name": "sql_query_func",
+ "name": "query_func",
  "description": f"""This is a tool useful to query a PostgreSQL database with the following tables, {titles_string}.
  There may also be more tables in the database if the number of tables is too large to process.
  This function also saves the results of the query to csv file called query.csv.""",
@@ -58,14 +58,14 @@ def tools_call(session_hash, data_source, titles):
  {
  "type": "function",
  "function": {
- "name": "doc_db_query_func",
+ "name": "query_func",
  "description": f"""This is a tool useful to build an aggregation pipeline to query a MongoDB NoSQL document database with the following collections, {titles_string}.
  There may also be more collections in the database if the number of tables is too large to process.
  This function also saves the results of the query to a csv file called query.csv.""",
  "parameters": {
  "type": "object",
  "properties": {
- "aggregation_pipeline": {
+ "queries": {
  "type": "string",
  "description": "The MongoDB aggregation pipeline to use in the search. Infer this from the user's message. It should be a question or a statement."
  },
@@ -74,7 +74,7 @@ def tools_call(session_hash, data_source, titles):
  "description": "The MongoDB collection to use in the search. Infer this from the user's message. It should be a question or a statement.",
  }
  },
- "required": ["aggregation_pipeline","db_collection"],
+ "required": ["queries","db_collection"],
  },
  },
  },
@@ -83,19 +83,19 @@ def tools_call(session_hash, data_source, titles):
  {
  "type": "function",
  "function": {
- "name": "graphql_query_func",
+ "name": "query_func",
  "description": f"""This is a tool useful to build a GraphQL query for a GraphQL API endpoint with the following types, {titles_string}.
  There may also be more types in the GraphQL endpoint if the number of types is too large to process.
  This function also saves the results of the query to a csv file called query.csv.""",
  "parameters": {
  "type": "object",
  "properties": {
- "graphql_query": {
+ "queries": {
  "type": "string",
  "description": "The GraphQL query to use in the search. Infer this from the user's message. It should be a question or a statement."
  }
  },
- "required": ["graphql_query"],
+ "required": ["queries"],
  },
  },
  },
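With every data source now advertising the same tool name, the model always emits a `query_func` call whose `queries` argument carries the SQL statement, aggregation pipeline, or GraphQL query. Condensed, the GraphQL entry now reads roughly as below (description abridged; the full f-string wording is in the hunk above):

```python
# Condensed sketch of the unified GraphQL tool entry after this PR
# (description abridged from the f-string shown in the diff above).
graphql_tool = {
    "type": "function",
    "function": {
        "name": "query_func",
        "description": "Build a GraphQL query for the endpoint; results are also saved to query.csv.",
        "parameters": {
            "type": "object",
            "properties": {
                "queries": {
                    "type": "string",
                    "description": "The GraphQL query to use in the search, inferred from the user's message.",
                },
            },
            "required": ["queries"],
        },
    },
}
```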