xuqinyang's picture
Update app.py
0f120a4
raw
history blame
No virus
6.5 kB
from typing import Iterator
import gradio as gr
from model import run
# System prompt handed to the model when pre-computing the cached examples
# (see process_example). Two paragraphs separated by a single newline, with
# no trailing newline.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant. "
    "Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature.\n"
    "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false information."
)

# Token budgets for generation and prompt length.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown rendered at the top of the page.
DESCRIPTION = "\n# Baichuan-13B-Chat\n"

# Markdown rendered at the bottom of the page (currently empty).
LICENSE = ""
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Produce the values that empty the textbox and remember what was typed.

    Args:
        message: Text currently in the input textbox.

    Returns:
        A pair ``(new_textbox_value, saved_message)`` where the first item is
        always the empty string and the second is ``message`` unchanged.
    """
    cleared_textbox = ''
    return cleared_textbox, message
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Add the user's message to the chat history with an empty reply slot.

    Args:
        message: The user message to show.
        history: Chat history as ``(user_message, reply)`` pairs; it is
            extended in place.

    Returns:
        The same ``history`` object, now ending with ``(message, '')``.
    """
    pending_turn = (message, '')
    history.append(pending_turn)
    return history
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Drop the most recent turn from the history and hand back its message.

    Args:
        history: Chat history as ``(user_message, reply)`` pairs; the last
            entry is removed in place.

    Returns:
        ``(history, message)`` where ``message`` is the user text of the
        removed turn, or ``''`` when the history was empty or that text was
        falsy (for example ``None``).
    """
    try:
        last_turn = history.pop()
    except IndexError:
        # Nothing to remove: report an empty message.
        return history, ''
    user_message, _reply = last_turn
    if not user_message:
        user_message = ''
    return history, user_message
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream the chat history as the model's reply to ``message`` grows.

    Args:
        message: The user message being answered.
        history_with_input: Chat history whose last entry is the pending
            turn for ``message``; that last entry is excluded from what is
            passed to ``run``.
        system_prompt: Forwarded to ``run`` unchanged.
        max_new_tokens: Forwarded to ``run`` unchanged.
        temperature: Forwarded to ``run`` unchanged.
        top_p: Forwarded to ``run`` unchanged.
        top_k: Forwarded to ``run`` unchanged.

    Yields:
        For every value produced by ``run``, a new list made of the earlier
        turns followed by ``(message, response)``.
    """
    # Everything except the trailing placeholder turn for the current message.
    earlier_turns = history_with_input[:-1]
    for partial_reply in run(message, earlier_turns, system_prompt,
                             max_new_tokens, temperature, top_p, top_k):
        yield earlier_turns + [(message, partial_reply)]
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run one example prompt to completion and return the final chat state.

    Used as the ``fn`` of ``gr.Examples`` (with ``cache_examples=True``), so
    the streamed generation is drained and only the last history is kept.

    Args:
        message: The example prompt.

    Returns:
        ``('', history)``: an empty string for the textbox and the last
        history yielded by ``generate``. If ``generate`` yields nothing the
        history is an empty list (previously this raised
        ``UnboundLocalError``).
    """
    # NOTE(review): 8192 new tokens exceeds MAX_MAX_NEW_TOKENS (2048) defined
    # in this file; value kept as-is, confirm whether it is intentional.
    final_history: list[tuple[str, str]] = []
    for final_history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 8192, 1,
                                  0.95, 50):
        pass
    return '', final_history
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Validation hook run before generation; currently accepts every input.

    The submit event chains call this step right before ``generate`` and only
    continue (via ``.success``) when it does not raise. At present it performs
    no check at all.

    Args:
        message: The submitted user message (unused).
        chat_history: The current chat history (unused).
        system_prompt: The system prompt (unused).

    Returns:
        Always ``None``.
    """
    # NOTE(review): MAX_INPUT_TOKEN_LENGTH is defined in this file but is not
    # enforced here; confirm whether a length check should be restored.
    return None
# Build the chat UI, wire the event chains, then start the app.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the most recently submitted message between chained event steps.
    saved_input = gr.State()

    # Fixed generation settings that are not exposed in the UI. They are
    # wrapped in gr.State because every entry of an event listener's
    # `inputs` list must be a Gradio component, not a bare str/int.
    system_prompt = gr.State('')
    max_new_tokens = gr.State(DEFAULT_MAX_NEW_TOKENS)

    with gr.Accordion(label='Advanced options', open=False):
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.85,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=5,
        )

    gr.Examples(
        examples=[
            '用中文回答,When is the best time to visit Beijing, and do you have any suggestions for me?',
            '用中文回答,特朗普是谁?',
        ],
        inputs=textbox,
        outputs=[textbox, chatbot],
        fn=process_example,
        cache_examples=True,
    )

    gr.Markdown(LICENSE)

    # Enter key in the textbox: clear + save the text, show it in the chat,
    # run the pre-generation check, then stream the reply if the check passed.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Submit button: same chain as pressing Enter in the textbox.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: remove the last turn, re-display its message, and generate again.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: remove the last turn and put its message back into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset the chat history and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Enable request queuing (at most 20 pending requests) and start the server.
demo.queue(max_size=20).launch()