tareknaous commited on
Commit
60159a1
1 Parent(s): 92ede04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -11,7 +11,7 @@ def chat(message, history):
11
  history = history or []
12
  new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
13
 
14
- if len(history) > 0:
15
  for i in range(0,len(history)):
16
  encoded_message = tokenizer.encode(history[i][0] + tokenizer.eos_token, return_tensors='pt')
17
  encoded_response = tokenizer.encode(history[i][1] + tokenizer.eos_token, return_tensors='pt')
@@ -24,7 +24,20 @@ def chat(message, history):
24
 
25
  bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
26
 
27
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  bot_input_ids = new_user_input_ids
29
 
30
  chat_history_ids = model.generate(bot_input_ids, max_length=1000, do_sample=True, top_p=0.9, temperature=0.8, pad_token_id=tokenizer.eos_token_id)
11
  history = history or []
12
  new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
13
 
14
+ if len(history) > 0 and len(history) < 2:
15
  for i in range(0,len(history)):
16
  encoded_message = tokenizer.encode(history[i][0] + tokenizer.eos_token, return_tensors='pt')
17
  encoded_response = tokenizer.encode(history[i][1] + tokenizer.eos_token, return_tensors='pt')
24
 
25
  bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
26
 
27
+ elif len(history) >= 2:
28
+ for i in range(len(history)-1, len(history)):
29
+ encoded_message = tokenizer.encode(history[i][0] + tokenizer.eos_token, return_tensors='pt')
30
+ encoded_response = tokenizer.encode(history[i][1] + tokenizer.eos_token, return_tensors='pt')
31
+ if i == (len(history)-1):
32
+ chat_history_ids = encoded_message
33
+ chat_history_ids = torch.cat([chat_history_ids,encoded_response], dim=-1)
34
+ else:
35
+ chat_history_ids = torch.cat([chat_history_ids,encoded_message], dim=-1)
36
+ chat_history_ids = torch.cat([chat_history_ids,encoded_response], dim=-1)
37
+
38
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
39
+
40
+ elif len(history) == 0:
41
  bot_input_ids = new_user_input_ids
42
 
43
  chat_history_ids = model.generate(bot_input_ids, max_length=1000, do_sample=True, top_p=0.9, temperature=0.8, pad_token_id=tokenizer.eos_token_id)