import html
import re

import gradio as gr
import torch
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer
# Use the GPU when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Local checkpoint directory; from_pretrained reads tokenizer and weights
# from this path (presumably downloaded ahead of time — confirm deployment).
model_name = "./blenderbot-1B-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
# nn.Module.to() moves the parameters and returns the same module object.
model = BlenderbotForConditionalGeneration.from_pretrained(model_name).to(device)
# Whole words (case-insensitive) that, as the last word of a message,
# restart the conversation instead of producing a reply.
_RESET_WORDS = frozenset({"bye", "goodbye", "hello", "hi"})
# Number of most recent messages (user and bot) fed to the model as context.
_MAX_HISTORY = 4


def _is_reset_message(text):
    """Return True when the last word of *text* is a greeting or farewell.

    Case and surrounding punctuation are ignored and only whole words count,
    so "Bye!", "BYE" and "Goodbye." match while "Othello" does not.
    """
    words = re.findall(r"[a-z]+", (text or "").lower())
    return bool(words) and words[-1] in _RESET_WORDS


def _render_history(history):
    """Render an alternating user/bot message list as escaped chat HTML.

    The list always ends with the bot's reply, so roles are assigned by
    counting back from the end. Counting from the front would mislabel every
    message once truncation leaves a bot message at index 0.
    """
    last = len(history) - 1
    rows = []
    for idx, msg in enumerate(history):
        cls = "bot" if (last - idx) % 2 == 0 else "user"
        # Messages come from users and the model: escape before embedding.
        rows.append("<div class='msg {}'> {}</div>".format(cls, html.escape(msg)))
    return "<div class='chatbot'>" + "".join(rows) + "</div>"


def get_reply(response, username=None, histories=None):
    """Generate a chatbot reply for *username* and render the conversation.

    Args:
        response: The user's latest message.
        username: Key under which this user's history is stored. A missing
            or blank name short-circuits with a prompt to enter one.
        histories: Dict mapping username -> list of alternating user/bot
            messages, carried between calls as the Gradio "state" value.
            ``None`` (what Gradio passes on the first call) means empty.

    Returns:
        A ``(html, histories)`` tuple: the rendered chat markup and the
        updated state dict.
    """
    # Fresh dict per call: a ``{}`` default would be shared by all callers.
    if histories is None:
        histories = {}
    if not username or not username.strip():
        return "<div class='chatbot'>Enter a username</div>", histories
    if _is_reset_message(response):
        histories[username] = []
        return "<div class='chatbot'>Chatbot restarted</div>", histories

    # Build a new list rather than appending in place, so a failing
    # generate() leaves the stored history untouched.
    history = (histories.get(username, []) + [response])[-_MAX_HISTORY:]
    # NOTE(review): turns are joined with a single space, and long histories
    # may exceed the model's maximum input length; confirm both against the
    # BlenderBot checkpoint's expectations.
    inputs = tokenizer(" ".join(history), return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    # Drop BOS/EOS/pad by token type rather than assuming their positions.
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    history.append(reply)
    histories[username] = history
    return _render_history(history), histories
# Styles for the chat transcript built by get_reply. The container selector
# must match the "chatbot" class get_reply emits: the flex column is what lets
# bot messages (align-self:self-end) sit on the opposite side from user ones.
# .footer is hidden (presumably Gradio's page footer — confirm).
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white}
.msg.bot {background-color:lightgray;align-self:self-end}
.footer {display:none !important}
"""
# Inputs: the message, the username, and a session "state" value that
# carries get_reply's per-user histories dict between calls. Outputs: the
# rendered chat HTML and the updated state.
# NOTE(review): gr.inputs.*, theme="default" and enable_queue= belong to
# the legacy Gradio interface API — pin the gradio version accordingly.
demo = gr.Interface(
    fn=get_reply,
    inputs=[
        gr.inputs.Textbox(placeholder="How are you?"),
        gr.inputs.Textbox(label="Username"),
        "state",
    ],
    outputs=["html", "state"],
    css=css,
    theme="default",
)
# enable_queue presumably keeps slow generate() calls from timing out — confirm.
demo.launch(debug=True, enable_queue=True)