bixentemal committed on
Commit
e4c96d3
1 Parent(s): 0fe8e5b

Chatbot DialoGPT

Files changed (1)
  1. app.py +18 -6
app.py CHANGED
@@ -7,16 +7,22 @@ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
 # chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
 # mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
+
 def converse(user_input, chat_history=[]):
     user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
+
     # keep history in the tensor
     bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
+
     # get response
     chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
     print(chat_history)
+
     response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
+
     print("starting to print response")
     print(response)
+
     # html for display
     html = "<div class='mybot'>"
     for x, mesg in enumerate(response):
@@ -25,11 +31,13 @@ def converse(user_input, chat_history=[]):
             clazz = "alicia"
         else:
             clazz = "user"
-        print("value of x")
-        print(x)
-        print("message")
-        print(mesg)
-        html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
+
+        print("value of x")
+        print(x)
+        print("message")
+        print(mesg)
+
+        html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
     html += "</div>"
     print(html)
     return html, chat_history
@@ -45,4 +53,8 @@ css = """
 .footer {display:none !important}
 """
 text = grad.inputs.Textbox(placeholder="Lets chat")
-grad.Interface(fn=converse, theme="default", inputs=[text, "state"], outputs=["html", "state"], css=css).launch()
+grad.Interface(fn=converse,
+               theme="default",
+               inputs=[text, "state"],
+               outputs=["html", "state"],
+               css=css).launch()
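
For context, converse() follows the standard multi-turn DialoGPT pattern: each user turn gets the EOS token appended, is concatenated onto the running tensor of token ids, and is passed to generate(); the decoded output is then split on <|endoftext|>, which yields alternating user/bot turns (presumably what the user/alicia class assignment keys on). Returning chat_history through the "state" output lets Gradio feed the accumulated token ids back in on the next call, so the model sees the whole conversation each turn. Below is a minimal, self-contained sketch of that generation loop outside Gradio; the variable names and sample inputs are illustrative, not taken from app.py.

# Sketch of the multi-turn DialoGPT pattern used by converse(); names are illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

history_ids = None  # token ids of the whole conversation so far
for user_text in ["Hello!", "Do you like Python?"]:
    # end the user turn with EOS, as app.py does with chat_tkn.eos_token
    new_ids = tok(user_text + tok.eos_token, return_tensors="pt").input_ids
    # keep history in the tensor
    input_ids = new_ids if history_ids is None else torch.cat([history_ids, new_ids], dim=-1)
    # generate() returns the full conversation, user turns included
    history_ids = model.generate(input_ids, max_length=1000, pad_token_id=tok.eos_token_id)
    # splitting on <|endoftext|> gives alternating user / bot turns;
    # the last non-empty segment is the newest bot reply
    turns = tok.decode(history_ids[0]).split("<|endoftext|>")
    print("user:", user_text)
    print("bot :", turns[-2])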