# DialoGPT-small / app.py
# Ahsen Khaliq
# Update app.py
# fcabd91
# raw history blame
# No virus
# 2.58 kB
import os
import subprocess
import sys

# Pin gradio before importing it: the UI below relies on the gradio 2.x API
# (gr.inputs / gr.outputs / gr.get_state / gr.set_state), which is presumably
# not available in later releases. `python -m pip` installs into the
# interpreter actually running this app, unlike a bare `pip` on PATH that may
# belong to a different Python. The exit status is ignored (check=False), as
# with the previous os.system call, so an already-installed gradio still works.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "gradio==2.3.5b0"],
    check=False,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

# DialoGPT-small tokenizer and weights, loaded once at start-up and shared by
# every request handled by dialogpt().
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
def _render_chat_html(history):
    """Render the (user message, bot reply) pairs in *history* as an HTML chat log.

    Both sides are HTML-escaped: the user text is untrusted input and the model
    can echo it back, so inserting either one raw would allow markup/script
    injection into the page.
    """
    import html

    parts = ["<div class='chatbot'>"]
    for user_msg, resp_msg in history:
        parts.append(f"<div class='user_msg'>{html.escape(user_msg)}</div>")
        parts.append(f"<div class='resp_msg'>{html.escape(resp_msg)}</div>")
    parts.append("</div>")
    return "".join(parts)


def dialogpt(text):
    """Generate a DialoGPT reply to *text* and return the whole chat as HTML.

    Per-session state (gr.get_state / gr.set_state) holds a dict with:
      - "history": list of (user text, bot reply) pairs shown in the UI;
      - "chat_history_ids": token ids of the conversation so far (or None),
        so every reply is conditioned on the earlier turns.

    Args:
        text: the user's new message.

    Returns:
        An HTML string containing every turn of the conversation so far.
    """
    # Upper bound on prompt + reply length passed to model.generate.
    max_length = 1000
    # Keep only the most recent context tokens so the prompt never fills
    # max_length on its own; otherwise a long chat would leave no room for
    # the model to generate a reply.
    max_context = max_length - 200

    state = gr.get_state() or {"history": [], "chat_history_ids": None}
    history = state["history"]
    chat_history_ids = state["chat_history_ids"]

    # encode the new user input, add the eos_token and return a PyTorch tensor
    new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
    # append the new user input tokens to the chat history, if there is one
    if chat_history_ids is None:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    bot_input_ids = bot_input_ids[:, -max_context:]

    chat_history_ids = model.generate(
        bot_input_ids, max_length=max_length, pad_token_id=tokenizer.eos_token_id
    )
    # the reply is everything generated after the prompt tokens
    response = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )

    history.append((text, response))
    gr.set_state({"history": history, "chat_history_ids": chat_history_ids})
    return _render_chat_html(history)
# Gradio 2.x UI: one text box in, the rendered chat log (HTML) out.
inputs = gr.inputs.Textbox(lines=1, label="Input Text")
outputs = gr.outputs.HTML(label="DialoGPT")
title = "DialoGPT"
description = "Gradio demo for Microsoft DialoGPT with Hugging Face transformers. To use it, simply input text or click one of the examples text to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1911.00536'>DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation</a> | <a href='https://github.com/microsoft/DialoGPT'>Github Repo</a> | <a href='https://huggingface.co/microsoft/DialoGPT-small'>Hugging Face DialoGPT-small</a></p>"
examples = [
    ["Hi, how are you?"],
    ["How far away is the moon?"],
]
# The selectors must match the class names in the markup dialogpt() returns
# ("chatbot" container, "user_msg" / "resp_msg" bubbles); the container was
# previously styled as ".chatbox", so the flex column layout never applied.
css = """
.chatbot {display:flex;flex-direction:column}
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
.resp_msg {background-color:lightgray;align-self:self-end}
"""
gr.Interface(
    dialogpt,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
).launch(debug=True)