Tonic commited on
Commit
aae1bd9
1 Parent(s): 700caa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -10,14 +10,16 @@ class ChatBot:
10
  def __init__(self):
11
  self.history = []
12
 
13
- def predict(self, input):
14
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
15
- flat_history = [item for sublist in self.history for item in sublist]
16
- bot_input_ids = torch.cat([torch.tensor(flat_history), new_user_input_ids], dim=-1) if self.history else new_user_input_ids
17
- chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
18
- self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
19
- response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
20
- return response
 
 
21
 
22
  bot = ChatBot()
23
 
 
10
  def __init__(self):
11
  self.history = []
12
 
13
+ def predict(self, input):
14
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
15
+ flat_history = [item for sublist in self.history for item in sublist]
16
+ flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0) # convert list to 2-D tensor
17
+ bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids
18
+ chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
19
+ self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
20
+ response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
21
+ return response
22
+
23
 
24
  bot = ChatBot()
25