bixentemal committed
Commit 0fe8e5b
1 Parent(s): 53ca93c

Chatbot DialoGPT

Files changed (1)
  1. app.py +47 -15
app.py CHANGED
@@ -1,16 +1,48 @@
- from transformers import T5ForConditionalGeneration, T5Tokenizer
  import gradio as grad
- text2text_tkn= T5Tokenizer.from_pretrained("t5-small")
- mdl = T5ForConditionalGeneration.from_pretrained("t5-small")
- def text2text_paraphrase(sentence1,sentence2):
-     inp1 = "mrpc sentence1: "+sentence1
-     inp2 = "sentence2: "+sentence2
-     combined_inp=inp1+" "+inp2
-     enc = text2text_tkn(combined_inp, return_tensors="pt")
-     tokens = mdl.generate(**enc)
-     response=text2text_tkn.batch_decode(tokens)
-     return response
- sent1=grad.Textbox(lines=1, label="Sentence1", placeholder="Text in English")
- sent2=grad.Textbox(lines=1, label="Sentence2", placeholder="Text in English")
- out=grad.Textbox(lines=1, label="Whether the sentence is acceptable or not")
- grad.Interface(text2text_paraphrase, inputs=[sent1,sent2], outputs=out).launch()
 
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BlenderbotForConditionalGeneration
+ import torch
+
+ chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
+
+
+ # chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
+ # mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
+ def converse(user_input, chat_history=[]):
+     user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
+     # keep the conversation history in the tensor fed to the model
+     bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
+     # generate a response; the returned token ids become the new history
+     chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
+     print(chat_history)
+     response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
+     print("starting to print response")
+     print(response)
+     # build HTML for display: even turns are the user, odd turns are the bot ("Alicia")
+     html = "<div class='mychat'>"
+     for x, mesg in enumerate(response):
+         if x % 2 != 0:
+             mesg = "Alicia: " + mesg
+             clazz = "alicia"
+         else:
+             clazz = "user"
+         print("value of x")
+         print(x)
+         print("message")
+         print(mesg)
+         html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
+     html += "</div>"
+     print(html)
+     return html, chat_history
+
+
  import gradio as grad
+
+ css = """
+ .mychat {display:flex;flex-direction:column}
+ .mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
+ .mesg.user {background-color:lightblue;color:white}
+ .mesg.alicia {background-color:orange;color:white;align-self:self-end}
+ .footer {display:none !important}
+ """
+ text = grad.inputs.Textbox(placeholder="Let's chat")
+ grad.Interface(fn=converse, theme="default", inputs=[text, "state"], outputs=["html", "state"], css=css).launch()
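
For reference, the turn-taking loop that converse() implements can be exercised from a plain console, without the Gradio interface. The sketch below is an illustration only and is not part of the commit; it reuses the microsoft/DialoGPT-medium checkpoint from app.py, while the history_ids variable and the fixed three-turn loop are assumptions made for the example.

# Console sketch of the history loop used by converse() (illustration only, not in the commit).
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

history_ids = None  # running tensor of all token ids exchanged so far (assumed name)
for _ in range(3):  # three user turns, chosen arbitrarily for the demo
    user_text = input("You: ")
    new_ids = tkn(user_text + tkn.eos_token, return_tensors="pt").input_ids
    # append the new turn to the accumulated history, as converse() does with torch.cat
    bot_input_ids = new_ids if history_ids is None else torch.cat([history_ids, new_ids], dim=-1)
    history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tkn.eos_token_id)
    # decode only the tokens generated after the user's input
    reply = tkn.decode(history_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True)
    print("Alicia:", reply)

Because the entire history is concatenated on every turn, the prompt grows with each exchange until it reaches the max_length=1000 cap used in app.py.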