File size: 2,183 Bytes
b19af15
bad32bd
b19af15
2524827
 
9e350ba
 
 
f8beb9a
9e350ba
 
f8beb9a
2524827
 
6091881
f8beb9a
2524827
0550020
f8beb9a
 
2524827
 
 
f8beb9a
b9146b8
2524827
33ca492
 
f8beb9a
 
649ab52
 
e889b71
cccac01
 
2d99d69
 
cccac01
9f25e80
649ab52
 
2524827
649ab52
e889b71
9f25e80
f8beb9a
72bd468
2524827
5e2ff12
cdb2aab
5e2ff12
f8beb9a
5f468a8
649ab52
 
5f468a8
f8beb9a
 
515c60f
cdb2aab
f8beb9a
cdb2aab
f8beb9a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch


# Conversational model used by converse(). DialoGPT is a decoder-only model:
# converse() feeds the whole token history back in and splits the decoded
# output on DialoGPT's end-of-text token to recover the individual turns.
_MODEL_ID = "microsoft/DialoGPT-medium"
chat_tkn = AutoTokenizer.from_pretrained(_MODEL_ID)
mdl = AutoModelForCausalLM.from_pretrained(_MODEL_ID)

# Alternative model kept for experimentation.
# NOTE(review): BlenderBot is an encoder-decoder model; converse()'s history
# concatenation and EOS-splitting presumably assume a decoder-only model --
# verify converse() before switching to these lines.
#chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
#mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")

def converse(user_input, chat_history=None):
    """Run one chat turn with the model and render the whole transcript as HTML.

    Args:
        user_input: The text the user just submitted.
        chat_history: Token ids of the conversation so far, exactly as returned
            by the previous call (the nested list produced by
            ``Tensor.tolist()``). ``None`` or an empty list starts a new
            conversation.

    Returns:
        A ``(html, chat_history)`` pair: the rendered transcript markup and the
        updated token-id history to pass back in on the next call.
    """
    from html import escape
    import logging

    log = logging.getLogger(__name__)

    # Never share a mutable default between calls, and tolerate a None state
    # on the first turn.
    if chat_history is None:
        chat_history = []

    # Encode the new utterance terminated by EOS, which separates turns.
    user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids

    # Prepend the previous turns. An empty history becomes a 1-D size-0
    # tensor, which torch.cat accepts and skips.
    bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)

    # generate() returns prompt + reply, i.e. the new full history.
    chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
    log.debug("chat history ids: %s", chat_history)

    # Decode and split into turns on the EOS token; turns alternate
    # user / bot, starting with the user.
    response = chat_tkn.decode(chat_history[0]).split(chat_tkn.eos_token)
    # A transcript ending in EOS leaves one empty trailing fragment. Drop only
    # that one, so the user/bot alternation of the remaining turns is kept.
    if response and response[-1] == "":
        response.pop()
    log.debug("decoded turns: %s", response)

    # The container class must match the stylesheet's flex-column rule
    # (.mychat) so that align-self on bot messages takes effect.
    parts = ["<div class='mychat'>"]
    for turn_index, mesg in enumerate(response):
        if turn_index % 2:
            clazz, text = "alicia", "Alicia:" + mesg
        else:
            clazz, text = "user", mesg
        # Escape user/model text so it cannot inject markup into the page.
        parts.append("<div class='mesg {}'> {}</div>".format(clazz, escape(text)))
    parts.append("</div>")
    html = "".join(parts)
    log.debug("rendered html: %s", html)
    return html, chat_history

import gradio as grad

# Stylesheet for the transcript rendered by converse().
# - .mychat lays its children out as a flex column, which is what lets
#   align-self on .mesg.alicia push bot messages to the right. The container
#   div must carry class "mychat" for this to apply.
# - Declarations must be separated by ';'. A ',' makes the whole declaration
#   invalid, so the browser drops it.
# - .footer hides the default Gradio footer.
css = """
.mychat {display:flex;flex-direction:column}
.mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.mesg.user {background-color:lightblue;color:white}
.mesg.alicia {background-color:orange;color:white;align-self:self-end}
.footer {display:none !important}
"""
text=grad.inputs.Textbox(placeholder="Lets chat together")
# "state" threads converse()'s returned chat_history back into the next call.
grad.Interface(fn=converse,
             theme="default",
             inputs=[text, "state"],
             outputs=["html", "state"],
             css=css).launch()