teaevo commited on
Commit
b54b3e0
1 Parent(s): 9d8ae89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -104,8 +104,9 @@ def sqlquery(input, history=[]):
104
  #sql_outputs = sql_model.generate(**sql_encoding)
105
  #sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
106
 
 
107
  # append the new user input tokens to the chat history
108
- bot_input_ids = torch.cat([torch.LongTensor(history), **sql_encoding], dim=-1)
109
 
110
  # generate a response
111
  history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
 
104
  #sql_outputs = sql_model.generate(**sql_encoding)
105
  #sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
106
 
107
+ sql_input_ids = sql_encoding["input_ids"]
108
  # append the new user input tokens to the chat history
109
+ bot_input_ids = torch.cat([torch.LongTensor(history), sql_input_ids], dim=-1)
110
 
111
  # generate a response
112
  history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()