Update app.py
Browse files
app.py
CHANGED
@@ -104,8 +104,9 @@ def sqlquery(input, history=[]):
|
|
104 |
#sql_outputs = sql_model.generate(**sql_encoding)
|
105 |
#sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
|
106 |
|
|
|
107 |
# append the new user input tokens to the chat history
|
108 |
-
bot_input_ids = torch.cat([torch.LongTensor(history),
|
109 |
|
110 |
# generate a response
|
111 |
history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
|
|
|
104 |
#sql_outputs = sql_model.generate(**sql_encoding)
|
105 |
#sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
|
106 |
|
107 |
+
sql_input_ids = sql_encoding["input_ids"]
|
108 |
# append the new user input tokens to the chat history
|
109 |
+
bot_input_ids = torch.cat([torch.LongTensor(history), sql_input_ids], dim=-1)
|
110 |
|
111 |
# generate a response
|
112 |
history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
|