|
import os |
|
import time |
|
import torch |
|
import gradio as gr |
|
|
|
from strings import TITLE, ABSTRACT, EXAMPLES |
|
from gen import get_pretrained_models, get_output |
|
|
|
# Load the pretrained model once, at import time; chat() passes this same
# object to get_output() on every request.
# NOTE(review): the meaning of the "13B" / "tokenizer" arguments is defined in
# gen.get_pretrained_models (presumably model size and tokenizer path) - confirm.
generator = get_pretrained_models("13B", "tokenizer")

# Module-level log of every user/bot exchange; chat() appends two entries per
# request. It is shared by all sessions and, in the code shown here, is never
# read or trimmed, so it grows for the lifetime of the process.
history = []
|
|
|
def chat(
    user_input,
    include_input,
    truncate,
    top_p,
    temperature,
    max_gen_len,
    state_chatbot
):
    """Generate a completion for ``user_input`` and stream it into the chat UI.

    This is a generator wired to ``textbox.submit``: each ``yield`` pushes the
    partially revealed response to both the ``gr.State`` and ``gr.Chatbot``
    outputs.

    Args:
        user_input: Prompt text typed by the user.
        include_input: If falsy, drop the first ``len(user_input)`` characters
            of the generated text (assumes the model output starts with the
            prompt - TODO confirm against ``get_output``).
        truncate: If truthy, cut the response just after its last ``"."`` so
            an unfinished trailing sentence is removed. Responses containing
            no period are left intact.
        top_p: Nucleus-sampling threshold forwarded to ``get_output``.
        temperature: Sampling temperature forwarded to ``get_output``.
        max_gen_len: Maximum generation length forwarded to ``get_output``.
        state_chatbot: List of ``(user, bot)`` pairs displayed so far.

    Yields:
        ``(state_chatbot, state_chatbot)`` after each newly revealed word.
    """
    bot_response = get_output(
        generator=generator,
        prompt=user_input,
        # The value comes from a UI slider and may arrive as a float;
        # presumably get_output uses it as a token count, so force an int.
        max_gen_len=int(max_gen_len),
        temperature=temperature,
        top_p=top_p,
    )[0]

    bot_response = _postprocess_response(
        bot_response, user_input, include_input, truncate
    )

    # Process-wide log shared by all sessions (see the module-level `history`).
    # NOTE(review): the bot turn is tagged "system"; "assistant" is the usual
    # role name - confirm nothing downstream depends on this value.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "system", "content": bot_response})

    # Build a new list instead of mutating the caller's state in place.
    state_chatbot = state_chatbot + [(user_input, None)]

    response = ""
    for word in bot_response.split(" "):
        time.sleep(0.1)  # artificial per-word delay for a "typing" effect
        response += word + " "
        state_chatbot[-1] = (user_input, response)
        yield state_chatbot, state_chatbot


def _postprocess_response(text, user_input, include_input, truncate):
    """Apply the include-input, line-break-to-HTML and truncation options."""
    if not include_input:
        text = text[len(user_input):]

    # gr.Chatbot renders HTML, so convert newlines to explicit breaks.
    text = text.replace("\n", "<br>")

    if truncate:
        last_period = text.rfind(".")
        # rfind returns -1 when there is no period; slicing to [:0] would
        # wipe out the whole response, so only truncate when one exists.
        if last_period != -1:
            text = text[:last_period + 1]
    return text
|
|
|
def reset_textbox():
    """Return a Gradio update that empties the prompt textbox after submit."""
    cleared = gr.update(value="")
    return cleared
|
|
|
with gr.Blocks(css="""#col_container {width: 95%; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}""") as demo:

    # Per-session transcript of (user, bot) pairs; chat() reads and yields it.
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

        with gr.Accordion("Example prompts", open=False):
            # Markdown bullet list: one "- <example>" line per entry.
            example_str = "\n" + "".join(f"- {example}\n" for example in EXAMPLES)
            gr.Markdown(example_str)

        chatbot = gr.Chatbot(elem_id='chatbot')
        textbox = gr.Textbox(placeholder="Enter a prompt")

        with gr.Accordion("Parameters", open=False):
            include_input = gr.Checkbox(value=True, label="Do you want to include the input in the generated text?")
            truncate = gr.Checkbox(value=True, label="Truncate the unfinished last words?")

            max_gen_len = gr.Slider(minimum=20, maximum=512, value=256, step=1, interactive=True, label="Max Generation Length")
            top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)")
            temperature = gr.Slider(minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature")

    # Input order must match chat()'s positional parameters.
    textbox.submit(
        chat,
        [textbox, include_input, truncate, top_p, temperature, max_gen_len, state_chatbot],
        [state_chatbot, chatbot]
    )
    # Second listener on the same event: clear the prompt box after submit.
    textbox.submit(reset_textbox, [], [textbox])

# Queueing is required for generator (streaming) handlers; the public API
# page is disabled.
demo.queue(api_open=False).launch()