teaevo committed
Commit 829e215 · 1 Parent(s): 8a4fd9e

Update app.py

Files changed (1)
app.py +12 -4
app.py CHANGED
@@ -24,9 +24,9 @@ tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
 model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
 
 # Load the SQL Model
-model_name = "microsoft/tapex-large-finetuned-wtq"
-sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
-sql_model = BartForConditionalGeneration.from_pretrained(model_name)
+sql_model_name = "microsoft/tapex-large-finetuned-wtq"
+sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
+sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
 
 data = {
     "year": [1896, 1900, 1904, 2004, 2008, 2012],
@@ -34,8 +34,11 @@ data = {
 }
 table = pd.DataFrame.from_dict(data)
 
+sql_response = []
+
 def predict(input, history=[]):
 
+    global sql_response
     # Check if the user input is a question
     is_question = "?" in input
 
@@ -45,11 +48,16 @@ def predict(input, history=[]):
     # append the new user input tokens to the chat history
     bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
 
+    if is_question:
+        sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
+        sql_outputs = sql_model.generate(**sql_encoding)
+        sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
+
     # generate a response
     history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
 
     # convert the tokens to text, and then split the responses into the right format
-    response = tokenizer.decode(history[0]).split("<|endoftext|>")
+    response = tokenizer.decode(history[0]).split("<|endoftext|>") + sql_response
     response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
     return response, history
 
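For reference, the table-QA branch added in this commit follows the documented usage of the microsoft/tapex-large-finetuned-wtq checkpoint. A minimal standalone sketch of that path, assuming a table like the one in app.py (the "city" column and the sample question are illustrative, since the diff truncates the full data dict):

import pandas as pd
from transformers import TapexTokenizer, BartForConditionalGeneration

# Same checkpoint as app.py
sql_model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

# Table mirrors app.py; the "city" column is an assumption taken from the
# TAPEX model card, since the diff truncates the full data dict
table = pd.DataFrame.from_dict({
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"],
})

# Encode table + question, generate, and decode the answer
query = "In which year did beijing host the Olympic Games?"
encoding = sql_tokenizer(table=table, query=query, return_tensors="pt")
outputs = sql_model.generate(**encoding)
print(sql_tokenizer.batch_decode(outputs, skip_special_tokens=True))
# e.g. [' 2008.0']

Note that batch_decode returns a list of strings, which is what lets predict concatenate sql_response onto the split chat responses before pairing them into tuples.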