jeraimondi's picture
Update app.py
b567a38
import gradio as gr
import os
import spacy
import time
import torch
from spacy.cli.download import download
from transformers import pipeline
from transformers import GPT2LMHeadModel, GPT2Tokenizer
def generate_response(user_input):
    """Generate a chatbot reply to ``user_input`` with the fine-tuned GPT-2 model.

    The input is stripped, lower-cased and wrapped as ``[BOS] <input> [BOT]``.
    The model continues that prompt. Everything generated after the ``[BOT]``
    token is decoded and cleaned up by ``postprocess_text``.

    Args:
        user_input: Raw text typed by the user.

    Returns:
        The post-processed reply string. If the generated sequence contains no
        ``[BOT]`` token, a fixed fallback apology is returned. (The previous
        version returned ``None`` here, which crashed the streaming UI.)
    """
    # Hard-coded id of the '[BOT]' token in the saved tokenizer's vocabulary.
    # NOTE(review): confirm this equals tokenizer.convert_tokens_to_ids('[BOT]').
    bot_token_id = 50263
    fallback = "Sorry, I don't understand your question. Can you try asking it in another way?"

    # Wrap the user text in the same marker tokens the model was fine-tuned on.
    prompt = ' '.join(['[BOS]', user_input.strip().lower(), '[BOT]'])
    input_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True).to(device)

    # Beam-search sampling (do_sample=True with num_beams=6), filtered by
    # top-k / top-p (nucleus). Per the transformers docs, max_length counts the
    # prompt tokens plus the newly generated ones.
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=50,
        top_k=30,
        top_p=0.95,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
        temperature=.7,
        num_beams=6,
    )

    # Only one sequence is requested (num_return_sequences=1), so use the first.
    output_tokens = sample_outputs[0].tolist()
    try:
        bot_token_index = output_tokens.index(bot_token_id)
    except ValueError:
        print('Unable to find [BOT] token.')
        return fallback

    # Decode only the reply that follows the [BOT] marker.
    decoded_text = tokenizer.decode(output_tokens[bot_token_index + 1:], skip_special_tokens=True)
    return postprocess_text(decoded_text)
# Reply shown whenever the generated text yields no usable sentence.
_FALLBACK_RESPONSE = "Sorry, I don't understand your question. Can you try asking it in another way?"


def _format_sentences(sentence_texts):
    """Turn raw sentence strings into the chatbot's final reply text.

    For each sentence:
      * strip surrounding whitespace;
      * drop it if it is shorter than 3 characters;
      * append '.' unless it already ends with '.', '?' or '!';
      * upper-case its first character, leaving the rest as generated.

    At most the first two surviving sentences are kept, joined by one space.

    Args:
        sentence_texts: Iterable of sentence strings.

    Returns:
        The formatted reply, or ``_FALLBACK_RESPONSE`` when no sentence survives.
    """
    cleaned = []
    for raw in sentence_texts:
        sentence = raw.strip()
        if len(sentence) < 3:
            continue
        if not sentence.endswith(('.', '?', '!')):
            sentence += '.'
        # Upper-case only the first letter; str.capitalize() would also
        # lower-case the rest (e.g. "GRUB" -> "grub").
        cleaned.append(sentence[0].upper() + sentence[1:])
    if not cleaned:
        return _FALLBACK_RESPONSE
    return ' '.join(cleaned[:2])


def postprocess_text(text):
    """Clean up raw generated text for display in the chat.

    Splits ``text`` into sentences with the module-level spaCy pipeline
    ``nlp``, then formats them with ``_format_sentences``: short fragments
    are dropped, punctuation and capitalisation are fixed, and at most two
    sentences are kept.

    Args:
        text: Decoded model output that follows the [BOT] marker.

    Returns:
        The reply string, or a fallback apology if nothing usable remains or
        sentence splitting fails.
    """
    try:
        sentence_texts = [sent.text for sent in nlp(text).sents]
    except Exception as err:
        # Best-effort: a segmentation failure should not break the chat, but
        # unlike the former bare `except:` this no longer swallows
        # KeyboardInterrupt/SystemExit, and the failure is reported.
        print(f'Unable to split response into sentences: {err}')
        return _FALLBACK_RESPONSE
    return _format_sentences(sentence_texts)
# Load the small English spaCy pipeline that postprocess_text() uses to split
# generated text into sentences. If the package is not installed in this
# runtime, spacy.load raises OSError. In that case, fetch it with spaCy's
# downloader (imported at the top of the file) and retry once.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    download('en_core_web_sm')
    nlp = spacy.load('en_core_web_sm')

# Hugging Face Hub id of the fine-tuned GPT-2 chatbot checkpoint.
saved_model = "jeraimondi/chatbot-ubuntu-gpt2"

# Tokenizer and model must come from the same checkpoint, because
# generate_response relies on the [BOS]/[BOT] tokens in this vocabulary.
tokenizer = GPT2Tokenizer.from_pretrained(saved_model)
model = GPT2LMHeadModel.from_pretrained(saved_model)

# Run on the GPU when one is available; generate_response moves its input ids
# to this same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Build the Gradio chat UI, wire up its events, and launch the app.
with gr.Blocks() as demo:
    # Chat transcript. The second avatar (bot side) is avatar.png, located next
    # to this script.
    # NOTE(review): the original call also passed avatar_css={...}. That is not
    # a documented gr.Chatbot parameter: unknown kwargs are at best ignored and
    # in newer Gradio releases rejected. Size the avatar via gr.Blocks(css=...)
    # if that is still wanted.
    chatbot = gr.Chatbot(
        bubble_full_width=False,
        avatar_images=(None, os.path.join(os.path.dirname(__file__), "avatar.png")),
    )
    msg = gr.Textbox(
        show_label=False,
        placeholder="Enter question and press enter",
        container=False,
    )
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the submitted message (reply pending) and clear the textbox."""
        return "", history + [[user_message, None]]

    def bot(history):
        """Answer the last user message, streaming the reply one character at a time.

        Yields the updated history after each character to give a typing effect.
        """
        user_message = history[-1][0]
        # Guard against a None reply so the loop below never iterates over None.
        bot_message = generate_response(user_message) or ""
        history[-1][1] = ""
        for character in bot_message:
            history[-1][1] += character
            time.sleep(0.05)
            yield history

    def vote(data: gr.LikeData):
        """Print like/dislike feedback on a bot message to stdout."""
        # f-string instead of '+' so a non-str data.value cannot raise TypeError.
        if data.liked:
            print(f"You upvoted this response: {data.value}")
        else:
            print(f"You downvoted this response: {data.value}")

    # On Enter: show the user message immediately (unqueued), then stream the reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    chatbot.like(vote, None, None)
    # Setting the chatbot value to None clears the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

# Queuing is enabled so the generator-based bot() can stream its output.
demo.queue()
demo.launch()