Dochee commited on
Commit
a6c61d9
1 Parent(s): 71fdbd9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
2
+ import torch
3
+
4
+
5
+
6
+ chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
7
+
8
+ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
9
+
10
+
11
+
12
+ #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
13
+
14
+
15
+ #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
16
+
17
+
18
+ def converse(user_input, chat_history=[]):
19
+
20
+ user_input_ids = chat_tkn(user_input +
21
+ chat_tkn.eos_token,
22
+ return_tensors='pt').input_ids
23
+
24
+ # keep history in the tensor
25
+ bot_input_ids = torch.cat([torch.LongTensor(chat_history),
26
+ user_input_ids],
27
+ dim=-1)
28
+
29
+ # get response
30
+ chat_history = mdl.generate(bot_input_ids,
31
+ max_length=1000,
32
+ pad_token_id = chat_tkn.eos_token_id).tolist()
33
+ print (chat_history)
34
+
35
+ response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
36
+
37
+ print("starting to print response")
38
+ print(response)
39
+
40
+ # html for display
41
+ html = "<div class='mybot'>"
42
+
43
+ for x, mesg in enumerate(response):
44
+
45
+ if x%2!=0 :
46
+ mesg = "Alicia:"+mesg
47
+ clazz = "alicia"
48
+ else :
49
+ clazz = "user"
50
+
51
+ print("value of x")
52
+ print(x)
53
+ print("message")
54
+
55
+ print (mesg)
56
+
57
+ html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
58
+ html += "</div>"
59
+ print(html)
60
+ return html, chat_history
61
+
62
+
63
+
64
+
65
+ import gradio as grad
66
+
67
+ css = """
68
+ .mychat {display:flex;flex-direction:column}
69
+ .mesg {padding:5px;margin-bottom:5px;border-
70
+ radius:5px;width:75%}
71
+ .mesg.user {background-color:lightblue;color:white}
72
+ .mesg.alicia {background-color:orange;color:white,align-
73
+ self:self-end}
74
+ .footer {display:none !important}
75
+ """
76
+ text = grad.inputs.Textbox(placeholder="Lets chat")
77
+
78
+ grad.Interface(fn=converse,
79
+ theme="default",
80
+ inputs=[text, "state"],
81
+ outputs=["html", "state"],
82
+ css=css).launch()
83
+