import gradio as gr
import os
import spacy
import time
import torch
from spacy.cli.download import download
from transformers import pipeline
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# define function to generate chatbot response
def generate_response(user_input):
    # add tokens to user input text
    user_input = ' '.join(['[BOS]', user_input.strip().lower(), '[BOT]'])

    # encode input
    input_ids = tokenizer.encode(user_input, return_tensors='pt', add_special_tokens=True).to(device)

    # generate with top_p (nucleus) sampling
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=50,
        top_k=30,
        top_p=0.95,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
        temperature=.7,
        num_beams=6
    )
    # obtain list of generated token ids (only one sequence is returned)
    output_tokens = sample_outputs[0].tolist()

    # find location of [BOT] token (id of the added [BOT] special token)
    bot_token_id = 50263
    try:
        bot_token_index = output_tokens.index(bot_token_id)
        # decode text after the [BOT] token
        decoded_text = tokenizer.decode(output_tokens[bot_token_index + 1:], skip_special_tokens=True)
        response = postprocess_text(decoded_text)  # call function to postprocess response
        return response  # return chatbot response
    # if [BOT] token is not found, fall back to a default reply
    except ValueError:
        print('Unable to find [BOT] token.')
        return "Sorry, I don't understand your question. Can you try asking it in another way?"
# define function to postprocess generated chatbot text
def postprocess_text(text):
    try:
        # construct doc object and create list of sentences
        doc = nlp(text)
        sentences = list(doc.sents)

        # capitalize first letter of each sentence
        # only consider a sentence if it is at least 3 chars long
        capitalized_sentences = []
        for sent in sentences:
            if len(sent.text.strip()) >= 3:
                sentence = sent.text.strip()
                if not sentence.endswith('.') and not sentence.endswith('?'):
                    sentence += '.'
                capitalized_sentences.append(sentence.capitalize())

        # if response is more than one sentence, only return first two sentences
        if len(capitalized_sentences) == 1:
            response = capitalized_sentences[0]
        elif len(capitalized_sentences) > 1:
            response = ' '.join(capitalized_sentences[:2])
        else:
            response = "Sorry, I don't understand your question. Can you try asking it in another way?"

        # return response
        return response.strip()
    except Exception:
        return "Sorry, I don't understand your question. Can you try asking it in another way?"
# load english language model
nlp = spacy.load('en_core_web_sm')

# saved model location
saved_model = "jeraimondi/chatbot-ubuntu-gpt2"

# load previously trained model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(saved_model)
model = GPT2LMHeadModel.from_pretrained(saved_model)

# set model to use GPU if available in runtime session
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# define and launch gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        bubble_full_width=False,
        avatar_images=(None, os.path.join(os.path.dirname(__file__), "avatar.png"))
    )
    msg = gr.Textbox(
        show_label=False,
        placeholder="Enter question and press enter",
        container=False
    )
    clear = gr.Button("Clear")
    # clear the textbox and append the user message to the chat history
    def user(user_message, history):
        return "", history + [[user_message, None]]

    # generate the bot reply and stream it to the chat window character by character
    def bot(history):
        user_message = history[-1][0]
        bot_message = generate_response(user_message)
        history[-1][1] = ""
        for character in bot_message:
            history[-1][1] += character
            time.sleep(0.05)
            yield history

    # log like/dislike feedback on a bot response
    def vote(data: gr.LikeData):
        if data.liked:
            print("You upvoted this response: " + data.value)
        else:
            print("You downvoted this response: " + data.value)

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    chatbot.like(vote, None, None)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()
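
# Optional local smoke test (a minimal sketch; the question below is only an
# illustrative placeholder). generate_response() can be exercised without the
# Gradio UI from a Python session, e.g.:
#
#   print(generate_response("how do i update all packages?"))
#
# Because sampling is enabled (do_sample=True), the reply will vary between runs.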