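# gpt-2-chatbot / app.py
# Last commit: a9f06e3 "Update app.py" by jdoexbox360 (file size 1.37 kB).
# (Header recovered from the Hugging Face file page; the "raw", "history",
# "blame", "contribute", "delete" links and the "No virus" scan badge were
# page UI, not part of the program.)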
import gradio as gr
from gradio.components import Slider, Textbox, Radio
import tensorflow as tf
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Hugging Face checkpoint used for both the tokenizer and the model.
MODEL_NAME = "ethzanalytics/ai-msgbot-gpt2-XL-dialogue"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
# Padding during generation falls back to the end-of-sequence token
# (presumably because the GPT-2 tokenizer defines no pad token of its own).
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, pad_token_id=tokenizer.eos_token_id)

# Labels written into the transcript before the user's line and before the
# model's reply; the model continues the text after the responder label.
script_speaker_name = "person alpha"
script_responder_name = "person beta"

# Running conversation transcript; output() reads it and stores the decoded
# model output back into it on every call.
convo = ""
def output(prompt, output_length):
    """Append *prompt* to the conversation, generate a reply, return the transcript.

    The running conversation lives in the module-level ``convo`` string. When
    the first whitespace-separated word of *prompt* is ``clear_convo()``, the
    transcript is reset and the remainder of the prompt becomes the message.

    Args:
        prompt: Text typed by the user.
        output_length: Upper bound on the number of newly generated tokens.
            The Gradio Slider may deliver this as a float, so it is coerced
            to ``int`` (minimum 1) before being passed to ``generate``.

    Returns:
        The decoded model output (conversation context, the new prompt line
        and the generated reply). It is also stored back into ``convo``.
    """
    global convo

    words = prompt.split(maxsplit=1)
    if words and words[0] == "clear_convo()":
        convo = ""
        # Split only once so any later "clear_convo()" text in the message is kept.
        prompt = prompt.split("clear_convo()", 1)[1].strip()

    sentence = f"{convo}\n{script_speaker_name}: {prompt}\n{script_responder_name}: "

    max_new_tokens = max(1, int(output_length))
    input_ids = tokenizer.encode(sentence, return_tensors="pt")

    # GPT-2's position embeddings cover only config.n_positions tokens
    # (context + generated). Because convo grows every turn, keep only the
    # most recent context tokens so the request fits instead of failing once
    # the transcript gets long. Older history is dropped from convo as a result.
    n_positions = getattr(model.config, "n_positions", None)
    if n_positions is not None:
        max_new_tokens = min(max_new_tokens, n_positions - 1)
        input_ids = input_ids[:, -(n_positions - max_new_tokens):]

    # max_new_tokens bounds only the generated tokens; the context is not counted.
    generated = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    # For a decoder-only model the result is the input tokens followed by the
    # continuation, so decoding it yields the full transcript.
    convo = tokenizer.decode(generated[0], skip_special_tokens=True)
    return convo
# Web UI: a text box for the message and a slider bounding the number of new
# tokens to generate; the returned transcript is shown as text.
# gradio.components.Slider takes its initial value via `value=` (not `default=`).
iface = gr.Interface(
    fn=output,
    inputs=[
        "text",
        Slider(minimum=1, maximum=1000, step=1, value=50, label="Output Length"),
    ],
    outputs="text",
)
iface.launch()