richardr1126 commited on
Commit
3c3d942
1 Parent(s): 648b445

Add threading for open ai api

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -8,6 +8,7 @@ import re
8
  import platform
9
  import openai
10
  import random
 
11
  from transformers import (
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
@@ -130,7 +131,7 @@ def extract_db_code(text):
130
 
131
  def generate_dummy_db(db_info, question):
132
  pre_prompt = "Generate a SQLite database with dummy data for this database from the DB Layout. Make sure you add dummy data relevant to the Question and don't write any SELECT statements or actual queries."
133
- prompt = pre_prompt + "\n\nDB Layout:" db_info + "\n\nQuestion: " + question
134
 
135
  while True:
136
  try:
@@ -216,11 +217,15 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
216
  do_sample=do_sample,
217
  )
218
 
 
 
 
 
 
219
  tokens = m.generate(**generate_kwargs)
220
 
221
- db_code = None
222
- if (num_return_sequences > 1):
223
- db_code = generate_dummy_db(db_info, input_message)
224
 
225
  responses = []
226
  for response in tokens:
 
8
  import platform
9
  import openai
10
  import random
11
+ import concurrent.futures
12
  from transformers import (
13
  AutoModelForCausalLM,
14
  AutoTokenizer,
 
131
 
132
  def generate_dummy_db(db_info, question):
133
  pre_prompt = "Generate a SQLite database with dummy data for this database from the DB Layout. Make sure you add dummy data relevant to the Question and don't write any SELECT statements or actual queries."
134
+ prompt = pre_prompt + "\n\nDB Layout:" + db_info + "\n\nQuestion: " + question
135
 
136
  while True:
137
  try:
 
217
  do_sample=do_sample,
218
  )
219
 
220
+ db_code_future = None
221
+ if num_return_sequences > 1:
222
+ with concurrent.futures.ThreadPoolExecutor() as executor:
223
+ db_code_future = executor.submit(generate_dummy_db, db_info, input_message)
224
+
225
  tokens = m.generate(**generate_kwargs)
226
 
227
+ if db_code_future:
228
+ db_code = db_code_future.result()
 
229
 
230
  responses = []
231
  for response in tokens: