|
from typing import Iterator |
|
|
|
import gradio as gr |
|
|
|
|
|
from model import run |
|
|
|
# Initial content of the system-prompt textbox; also passed to generate()
# by process_example().
DEFAULT_SYSTEM_PROMPT = ""

# Upper bound and initial position of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

# NOTE(review): not referenced by any code visible in this file;
# check_input_token_length() is currently a no-op — confirm whether this
# limit is meant to be enforced there.
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown rendered at the top of the page.
DESCRIPTION = """
# 玉刚六号改/yugangVI-Chat
"""

# Attribution text rendered as Markdown at the bottom of the page.
LICENSE = "基于Baichuan-13B-Chat以及https://github.com/ouwei2013/baichuan13b.cpp"
|
|
|
|
|
|
|
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Empty the input box while preserving what the user typed.

    Args:
        message: Text currently held by the input textbox.

    Returns:
        A pair ``('', message)``. The event wiring in this file routes the
        first item back to the textbox (clearing it) and the second into the
        ``saved_input`` state.
    """
    return '', message
|
|
|
|
|
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Show the user's message in the chat with an empty reply slot.

    Args:
        message: The user message to display.
        history: Existing ``(user, bot)`` turns. This list is extended in
            place.

    Returns:
        The same ``history`` list, now ending with ``(message, '')``.
    """
    # The empty string is a placeholder for the reply; generate() drops this
    # last entry again before calling the model.
    history.append((message, ''))
    return history
|
|
|
|
|
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent turn from the chat history.

    Args:
        history: ``(user, bot)`` turns. The last entry is popped in place.

    Returns:
        ``(history, message)`` where ``message`` is the user text of the
        removed turn, or ``''`` when the history was already empty or the
        removed user text was falsy (e.g. ``None``).
    """
    try:
        message, _ = history.pop()
    except IndexError:
        # Nothing to remove: the history is already empty.
        message = ''
    return history, message or ''
|
|
|
|
|
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream chat histories as the model produces its reply.

    Args:
        message: The user message to answer.
        history_with_input: Chat history whose last entry is the pending
            turn for ``message``; that last entry is excluded from what is
            passed to the model. An empty list is accepted.
        system_prompt: Forwarded unchanged to ``run``.
        max_new_tokens: Forwarded unchanged to ``run``.
        temperature: Forwarded unchanged to ``run``.
        top_p: Forwarded unchanged to ``run``.
        top_k: Forwarded unchanged to ``run``.

    Yields:
        For every value produced by ``run``, a new list made of the prior
        history followed by ``(message, value)``.
    """
    # Strip the trailing placeholder turn so the model only sees completed
    # turns; ``message`` is passed separately.
    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens,
                    temperature, top_p, top_k)
    for response in generator:
        # Each yielded value replaces the reply shown for the current turn.
        # NOTE(review): presumably ``run`` yields the cumulative text so far
        # rather than deltas — confirm against model.run.
        yield history + [(message, response)]
|
|
|
|
|
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run an example prompt to completion and return the final chat state.

    Args:
        message: The example prompt.

    Returns:
        ``('', history)`` where ``history`` is the last value yielded by
        ``generate``. If ``generate`` yields nothing, ``history`` is
        ``[(message, '')]`` instead of raising ``UnboundLocalError``.
    """
    # NOTE(review): 8192 is larger than MAX_MAX_NEW_TOKENS (2048), the cap
    # the UI slider enforces — confirm this is intentional.
    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 8192, 1, 0.95, 50)
    # Fallback for an empty stream: the prompt with an empty reply.
    final_history: list[tuple[str, str]] = [(message, '')]
    # Drain the stream; only the last yielded history is kept.
    for final_history in generator:
        pass
    return '', final_history
|
|
|
|
|
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Placeholder validation step; currently accepts every input.

    The function ignores all of its arguments and performs no check.

    Args:
        message: The pending user message (unused).
        chat_history: The chat history (unused).
        system_prompt: The system prompt (unused).

    Returns:
        None, always.
    """
    # NOTE(review): MAX_INPUT_TOKEN_LENGTH is defined in this file but not
    # enforced here — confirm whether a length check should be restored.
    return None
|
|
|
|
|
# Build the chat page. Components and their event listeners are declared
# inside the Blocks context.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    # Conversation view and the message entry row.
    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='请输入/Type a message...',
                scale=10,
            )
            submit_button = gr.Button('提交/Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    # NOTE(review): the original nesting of this row (inside or outside the
    # Group above) could not be recovered from the source; it is placed at
    # Blocks level here — confirm the intended layout.
    with gr.Row():
        retry_button = gr.Button('🔄 重来/Retry', variant='secondary')
        undo_button = gr.Button('↩️ 撤销/Undo', variant='secondary')
        clear_button = gr.Button('🗑️ 清除/Clear', variant='secondary')

    # Holds the last submitted message after the textbox has been cleared.
    saved_input = gr.State()

    with gr.Accordion(label='进阶设置/Advanced options', open=False):
        system_prompt = gr.Textbox(label='预设引导词/System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='情感温度/Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.85,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=5,
        )

    gr.Examples(
        examples=[
            '中华人民共和国的首都是?',
        ],
        inputs=textbox,
        outputs=[textbox, chatbot],
        fn=process_example,
        cache_examples=True,
    )

    gr.Markdown(LICENSE)

    # Inputs for generate(), in its positional parameter order. Shared by
    # the submit, button-click and retry chains below.
    generation_inputs = [
        saved_input,
        chatbot,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
    ]

    # Enter key in the textbox: save and clear the input, show it in the
    # chat, run the validation step, then stream the reply. generate() is
    # attached with .success(), so it only runs if the validation step did
    # not raise.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=generation_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Submit button: same pipeline as pressing Enter in the textbox.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=generation_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Retry: drop the last turn, put its user message back into the chat,
    # and generate again. This chain does not call
    # check_input_token_length.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=generation_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Undo: drop the last turn and restore its user message to the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset the chat history and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Enable the request queue with max_size=20, then start the server.
demo.queue(max_size=20).launch()
|
|