richardr1126
commited on
Commit
•
af675f5
1
Parent(s):
99e7c03
Choose 1st response
Browse files
app.py
CHANGED
@@ -144,6 +144,7 @@ def generate_dummy_db(db_info, question):
|
|
144 |
response_text = response['choices'][0]['message']['content']
|
145 |
|
146 |
db_code = extract_db_code(response_text)
|
|
|
147 |
|
148 |
return db_code
|
149 |
|
@@ -183,6 +184,9 @@ def test_query_on_dummy_db(db_code, query):
|
|
183 |
|
184 |
|
185 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
|
|
|
|
|
|
186 |
stop_token_ids = tok.convert_tokens_to_ids(["###"])
|
187 |
class StopOnTokens(StoppingCriteria):
|
188 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
@@ -237,7 +241,7 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
|
|
237 |
responses.append(query)
|
238 |
|
239 |
# Choose a random response from responses
|
240 |
-
output = responses[
|
241 |
|
242 |
if log:
|
243 |
# Log the request to Firestore
|
|
|
144 |
response_text = response['choices'][0]['message']['content']
|
145 |
|
146 |
db_code = extract_db_code(response_text)
|
147 |
+
print(db_code)
|
148 |
|
149 |
return db_code
|
150 |
|
|
|
184 |
|
185 |
|
186 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
187 |
+
if num_return_sequences >= num_beams:
|
188 |
+
gr.Warning("Num return sequences must be less than or equal to num beams.")
|
189 |
+
|
190 |
stop_token_ids = tok.convert_tokens_to_ids(["###"])
|
191 |
class StopOnTokens(StoppingCriteria):
|
192 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
241 |
responses.append(query)
|
242 |
|
243 |
# Choose a random response from responses
|
244 |
+
output = responses[0] if len(responses) > 0 else "###"
|
245 |
|
246 |
if log:
|
247 |
# Log the request to Firestore
|