# LLaMA-7B chat demo — app.py
# Hugging Face Space by chansung (commit fdc528c).
import os
import time
import torch
import gradio as gr
from strings import TITLE, ABSTRACT, EXAMPLES
from gen import get_pretrained_models, get_output, setup_model_parallel
# Single-process "distributed" environment so the model-parallel setup
# initializes correctly on one machine (rank 0 of world size 1).
os.environ.update({
    "RANK": "0",
    "WORLD_SIZE": "1",
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "50505",
})

local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Running conversation log appended to by chat(); kept for the process lifetime.
history = []
def chat(
    user_input,
    include_input,
    truncate,
    top_p,
    temperature,
    max_gen_len,
    state_chatbot,
):
    """Generate a bot reply for ``user_input`` and stream it word by word.

    This is a generator: it yields ``(state_chatbot, state_chatbot)`` after
    each word so Gradio can update both the State and the Chatbot components,
    producing a typing effect.

    Args:
        user_input: The prompt typed by the user.
        include_input: If False, strip the echoed prompt from the model output.
        truncate: If True, cut the reply at its last sentence-ending period.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.
        max_gen_len: Maximum number of tokens to generate.
        state_chatbot: List of (user, bot) message pairs kept in gr.State.
    """
    bot_response = get_output(
        generator=generator,
        prompt=user_input,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p)[0]

    # The model echoes the prompt at the start of its output; optionally
    # remove that first phrase.
    if not include_input:
        bot_response = bot_response[len(user_input):]

    # The Chatbot component renders HTML, so newlines must become <br>.
    bot_response = bot_response.replace("\n", "<br>")

    # Trim an unfinished trailing phrase by cutting at the last period.
    # Fix: the original wrapped str.rfind in a bare `try/except: pass`, but
    # rfind never raises — it returns -1, and slicing with [:-1 + 1] == [:0]
    # silently erased the whole response when no period was present. Now we
    # only truncate when a period actually exists.
    if truncate:
        last_period = bot_response.rfind(".")
        if last_period != -1:
            bot_response = bot_response[:last_period + 1]

    history.append({
        "role": "user",
        "content": user_input
    })
    history.append({
        "role": "system",
        "content": bot_response
    })

    # New list (not in-place append) so Gradio sees a changed State value.
    state_chatbot = state_chatbot + [(user_input, None)]

    # Stream the reply one word at a time; the short sleep paces the effect.
    response = ""
    for word in bot_response.split(" "):
        time.sleep(0.1)
        response += word + " "
        state_chatbot[-1] = (user_input, response)
        yield state_chatbot, state_chatbot
def reset_textbox():
    """Return a Gradio update that clears the prompt textbox."""
    cleared = gr.update(value='')
    return cleared
# Page-level CSS: center the main column and give the chat log a fixed,
# scrollable height.
_CSS = """#col_container {width: 95%; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}"""

with gr.Blocks(css=_CSS) as demo:
    # Per-session list of (user, bot) message pairs fed to/from chat().
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

        with gr.Accordion("Example prompts", open=False):
            # Render the example prompts as a markdown bullet list.
            gr.Markdown("\n" + "".join(f"- {example}\n" for example in EXAMPLES))

        chatbot = gr.Chatbot(elem_id='chatbot')
        textbox = gr.Textbox(placeholder="Enter a prompt")

        with gr.Accordion("Parameters", open=False):
            include_input = gr.Checkbox(value=True, label="Do you want to include the input in the generated text?")
            truncate = gr.Checkbox(value=True, label="Truncate the unfinished last words?")
            max_gen_len = gr.Slider(minimum=20, maximum=512, value=256, step=1, interactive=True, label="Max Genenration Length",)
            top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",)
            temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)

    # Submitting the textbox streams the reply into the chatbot, then a
    # second listener clears the textbox for the next prompt.
    textbox.submit(
        chat,
        [textbox, include_input, truncate, top_p, temperature, max_gen_len, state_chatbot],
        [state_chatbot, chatbot],
    )
    textbox.submit(reset_textbox, [], [textbox])

demo.queue(api_open=False).launch()