Tonic commited on
Commit
7339879
1 Parent(s): d483f6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -24
app.py CHANGED
@@ -20,28 +20,30 @@ examples = [["How are you?"]]
20
 
21
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
22
  tokenizer.padding_side = 'left'
 
23
 
24
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
25
-
26
- def predict(input, history=[]):
27
- new_user_input_ids = tokenizer.encode(input, return_tensors="pt")
28
-
29
- bot_input_ids = torch.cat([torch.tensor(history), new_user_input_ids], dim=-1) if history else new_user_input_ids
30
-
31
- chat_history_ids = model.generate(bot_input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id)
32
-
33
- response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
34
-
35
- return response
36
-
37
- iface = gr.Interface(
38
- fn=predict,
39
- title=title,
40
- description=description,
41
- examples=examples,
42
- inputs="text",
43
- outputs="text",
44
- theme="ParityError/Anime",
45
- )
46
-
47
- iface.launch()
 
 
20
 
21
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
22
  tokenizer.padding_side = 'left'
23
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
24
 
25
+ class ChatBot:
26
+ def __init__(self):
27
+ self.history = []
28
+
29
+ def predict(self, input):
30
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
31
+ bot_input_ids = torch.cat([torch.tensor(self.history), new_user_input_ids], dim=-1) if self.history else new_user_input_ids
32
+ chat_history_ids = model.generate(bot_input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id)
33
+ self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist())
34
+ response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
35
+ return response
36
+
37
+ bot = ChatBot()
38
+
39
+ iface = gr.Interface(
40
+ fn=bot.predict,
41
+ title=title,
42
+ description=description,
43
+ examples=examples,
44
+ inputs="text",
45
+ outputs="text",
46
+ theme="ParityError/Anime",
47
+ )
48
+
49
+ iface.launch()