richardr1126 commited on
Commit
99e7c03
1 Parent(s): e6db544
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -7,6 +7,7 @@ import time
7
  import re
8
  import platform
9
  import openai
 
10
  from transformers import (
11
  AutoModelForCausalLM,
12
  AutoTokenizer,
@@ -127,9 +128,9 @@ def extract_db_code(text):
127
  matches = re.findall(pattern, text, re.DOTALL)
128
  return [match.strip() for match in matches]
129
 
130
- def generate_dummy_db(db_info, question, query):
131
- pre_prompt = "Generate a SQLite database with dummy data for this database, output the SQL code in a SQL code block. Make sure you add dummy data relevant to the question and query.\n\n"
132
- prompt = pre_prompt + db_info + "\n\nQuestion: " + question + "\nQuery: " + query
133
 
134
  while True:
135
  try:
@@ -212,6 +213,10 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
212
 
213
  tokens = m.generate(**generate_kwargs)
214
 
 
 
 
 
215
  responses = []
216
  for response in tokens:
217
  response_text = tok.decode(response, skip_special_tokens=True)
@@ -223,7 +228,6 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
223
  if (num_return_sequences > 1):
224
  query = query.replace("\n", " ").replace("\t", " ").strip()
225
  # Test against dummy database
226
- db_code = generate_dummy_db(db_info, input_message, query)
227
  success = test_query_on_dummy_db(db_code, query)
228
  # Format again
229
  query = format(query) if format_sql else query
@@ -232,8 +236,8 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
232
  else:
233
  responses.append(query)
234
 
235
- # Choose the first response
236
- output = responses[0] if responses else ""
237
 
238
  if log:
239
  # Log the request to Firestore
 
7
  import re
8
  import platform
9
  import openai
10
+ import random
11
  from transformers import (
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
 
128
  matches = re.findall(pattern, text, re.DOTALL)
129
  return [match.strip() for match in matches]
130
 
131
+ def generate_dummy_db(db_info, question):
132
+ pre_prompt = "Generate a SQLite database with dummy data for this database, output the SQL code in a SQL code block. Make sure you add dummy data relevant to the question.\n\n"
133
+ prompt = pre_prompt + db_info + "\n\nQuestion: " + question
134
 
135
  while True:
136
  try:
 
213
 
214
  tokens = m.generate(**generate_kwargs)
215
 
216
+ db_code = None
217
+ if (num_return_sequences > 1):
218
+ db_code = generate_dummy_db(db_info, input_message)
219
+
220
  responses = []
221
  for response in tokens:
222
  response_text = tok.decode(response, skip_special_tokens=True)
 
228
  if (num_return_sequences > 1):
229
  query = query.replace("\n", " ").replace("\t", " ").strip()
230
  # Test against dummy database
 
231
  success = test_query_on_dummy_db(db_code, query)
232
  # Format again
233
  query = format(query) if format_sql else query
 
236
  else:
237
  responses.append(query)
238
 
239
+ # Choose a random response from responses
240
+ output = responses[random.randint(0, len(responses)-1)] if len(responses) > 0 else "###"
241
 
242
  if log:
243
  # Log the request to Firestore