chansung's picture
Update app.py
13272b6
raw
history blame
No virus
12.3 kB
import os
import re
import json
import copy
import gradio as gr
from llama2 import GradioLLaMA2ChatPPManager
from llama2 import gen_text
from styles import MODEL_SELECTION_CSS
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
from templates import templates
from constants import DEFAULT_GLOBAL_CTX
from pingpong import PingPong
from pingpong.context import CtxLastWindowStrategy
# Hugging Face access token, read from the environment (None when unset).
# It is forwarded to gen_text() as `hf_token`.
TOKEN = os.environ.get('HF_TOKEN')

# Hub identifier of the chat model; forwarded to gen_text() as `hf_model`.
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'
def build_prompts(ppmanager, global_context, win_size=3):
    """Build the prompt for *ppmanager* using a last-window context strategy.

    Args:
        ppmanager: conversation manager to build the prompt from.
        global_context: value assigned to the ``ctx`` attribute of the copy
            used for prompt building.
        win_size: argument passed to ``CtxLastWindowStrategy``.

    Returns:
        Whatever the ``CtxLastWindowStrategy`` instance returns when called
        with the prepared manager copy.
    """
    # Operate on a deep copy so assigning `ctx` does not touch the caller's
    # manager object.
    scratch_manager = copy.deepcopy(ppmanager)
    scratch_manager.ctx = global_context
    return CtxLastWindowStrategy(win_size)(scratch_manager)
# Example prompts: one button label per line of examples.txt.
# The context manager closes the handle as soon as the text is read
# (previously the file object was left open for the process lifetime).
with open("examples.txt", "r") as ex_file:
    examples = ex_file.read().split("\n")
# Filled with one gr.Button per example while the UI is being built.
ex_btns = []

# Conversation channel titles: one per line of channels.txt.
with open("channels.txt", "r") as chl_file:
    channels = chl_file.read().split("\n")
# Filled with one gr.Button per channel while the UI is being built.
channel_btns = []
def get_placeholders(text):
    """Return the text found inside each ``[...]`` pair in *text*.

    A placeholder is an opening square bracket, any run of characters other
    than ``]`` (possibly empty), and a closing square bracket. The brackets
    themselves are not included in the results.

    Args:
        text: template string to scan.

    Returns:
        A list of the bracketed substrings, in order of appearance; an empty
        list when *text* contains no bracket pair.
    """
    return re.findall(r"\[([^\]]*)\]", text)
def fill_up_placeholders(txt):
    """Prepare the template widgets for the chosen template text.

    Args:
        txt: the selected template string.

    Returns:
        A 5-tuple: an update showing *txt* in the template display, one
        update for each of the three placeholder textboxes (visible only
        when the template has that many placeholders, with the placeholder
        name as hint text), and the value for the instruction textbox --
        *txt* itself when it has no placeholders, otherwise an empty string.
    """
    names = get_placeholders(txt)
    found = len(names)

    def slot_update(position):
        # `position` is 1-based: slot N is used only when the template
        # contains at least N placeholders.
        in_use = found >= position
        return gr.update(
            visible=in_use,
            placeholder=names[position - 1] if in_use else ""
        )

    template_update = gr.update(visible=True, value=txt)
    instruction_value = txt if found < 1 else ""

    return (
        template_update,
        slot_update(1),
        slot_update(2),
        slot_update(3),
        instruction_value
    )
async def chat_stream(
    idx, local_data, instruction_txtbox, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
):
    """Generate a reply for the selected conversation and stream UI updates.

    Args:
        idx: index of the conversation in *local_data* to continue.
        local_data: iterable of JSON-serializable conversation objects.
        instruction_txtbox: the user's message text.
        chat_state: mapping whose ``"ppmanager_type"`` entry provides
            ``from_json`` for rebuilding conversation managers.
        global_context: context handed to ``build_prompts``.
        res_temp, res_topk, res_rpen, res_mnts, res_sample: generation
            settings forwarded to ``gen_text`` as ``temperature``, ``top_k``,
            ``repetition_penalty``, ``max_new_tokens`` and ``do_sample``.
        ctx_num_lconv: window size handed to ``build_prompts``.

    Yields:
        ``(ppm.build_uis(), str(res))`` after each item produced by
        ``gen_text``.
    """
    # Rebuild one manager per stored conversation.
    res = []
    for stored in local_data:
        res.append(chat_state["ppmanager_type"].from_json(json.dumps(stored)))

    ppm = res[idx]
    # Start a new exchange with an empty reply; it is filled in below.
    ppm.add_pingpong(PingPong(instruction_txtbox, ""))

    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
    generation_params = {
        'max_new_tokens': res_mnts,
        'do_sample': res_sample,
        'return_full_text': False,
        'temperature': res_temp,
        'top_k': res_topk,
        'repetition_penalty': res_rpen
    }

    async for result in gen_text(
        prompt, hf_model=MODEL_ID, hf_token=TOKEN, parameters=generation_params
    ):
        ppm.append_pong(result)
        # NOTE(review): yielding per generated item (inside the loop) is
        # inferred from the streaming use of this handler -- confirm.
        yield ppm.build_uis(), str(res)
def channel_num(btn_title):
    """Return the index of the channel whose title equals *btn_title*.

    When several channels share the title, the index of the last one is
    returned. When none matches, 0 is returned.
    """
    hits = [pos for pos, title in enumerate(channels) if title == btn_title]
    return hits[-1] if hits else 0
def set_chatbot(btn, ld, state):
    """Switch the chat view to the conversation selected by a channel button.

    Args:
        btn: title of the clicked channel button.
        ld: iterable of JSON-serializable conversation objects.
        state: mapping whose ``"ppmanager_type"`` entry provides ``from_json``.

    Returns:
        A 4-tuple: the selected conversation's ``build_uis()`` result, its
        index, an update that is visible only when the conversation has no
        pingpongs, and an update that is interactive only when it has some.
    """
    choice = channel_num(btn)

    managers = []
    for ppm_str in ld:
        managers.append(state["ppmanager_type"].from_json(json.dumps(ppm_str)))

    selected = managers[choice]
    empty = len(selected.pingpongs) == 0

    return (
        selected.build_uis(),
        choice,
        gr.update(visible=empty),
        gr.update(interactive=not empty),
    )
def set_example(btn):
    """Return the clicked example's text and an update with ``visible=False``.

    Wired so the text goes to the instruction textbox and the update hides
    the examples popup.
    """
    hide_popup = gr.update(visible=False)
    return btn, hide_popup
def get_final_template(
    txt, placeholder_txt1, placeholder_txt2, placeholder_txt3
):
    """Substitute user-provided values into the template's placeholders.

    The first three placeholders found in *txt* are paired, in order, with
    the three textbox values. Every ``[name]`` occurrence of a placeholder is
    replaced by its value; a placeholder whose value is an empty string is
    left as it is.

    Returns:
        A 4-tuple: the filled-in prompt followed by three empty strings
        (the values sent back to the three placeholder textboxes).
    """
    entered_values = (placeholder_txt1, placeholder_txt2, placeholder_txt3)
    example_prompt = txt

    # zip() stops at the shorter sequence, so templates with fewer than
    # three placeholders simply use fewer textbox values.
    for name, value in zip(get_placeholders(txt), entered_values):
        if value == "":
            continue
        example_prompt = example_prompt.replace(f"[{name}]", value)

    return example_prompt, "", "", ""
# Build the Gradio UI, attach the event handlers, then launch the app.
# NOTE(review): the leading indentation of this section has been lost in this
# copy of the file, so the nesting of the `with` blocks below cannot be read
# off the text -- restore it from version control before running.
with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
with gr.Column() as chat_view:
# Per-session state: index of the selected conversation, and the manager
# class used to rebuild conversations from JSON.
idx = gr.State(0)
chat_state = gr.State({
"ppmanager_type": GradioLLaMA2ChatPPManager
})
# Hidden JSON component holding the serialized conversations; it is an
# output of demo.load (GET_LOCAL_STORAGE) and of chat_stream.
local_data = gr.JSON({}, visible=False)
with gr.Row():
with gr.Column(scale=1, min_width=180):
gr.Markdown("GradioChat", elem_id="left-top")
with gr.Column(elem_id="left-pane"):
# No event handler is attached to this button in the visible code.
chat_back_btn = gr.Button("Back", elem_id="chat-back-btn")
with gr.Accordion("Histories", elem_id="chat-history-accordion", open=True):
# One button per channel title; only the first gets the highlight class.
channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))
for channel in channels[1:]:
channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))
with gr.Column(scale=8, elem_id="right-pane"):
# Examples popup: created hidden; set_chatbot and set_example return
# visibility updates for it.
with gr.Column(
elem_id="initial-popup", visible=False
) as example_block:
with gr.Row(scale=1):
with gr.Column(elem_id="initial-popup-left-pane"):
gr.Markdown("GradioChat", elem_id="initial-popup-title")
gr.Markdown("Making the community's best AI chat models available to everyone.")
with gr.Column(elem_id="initial-popup-right-pane"):
gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
# NOTE(review): the characters before "repository" look like a mis-encoded
# arrow glyph -- confirm against the original file and fix the encoding.
gr.Markdown("check out the [β†— repository](https://huggingface.co/spaces/chansung/test-multi-conv)")
with gr.Column(scale=1):
gr.Markdown("Examples")
with gr.Row():
# One button per line read from examples.txt.
for example in examples:
ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))
with gr.Column(elem_id="aux-btns-popup", visible=True):
with gr.Row():
# No click handlers are attached to these three buttons in the visible
# code; `regenerate` only receives interactivity updates from set_chatbot.
stop = gr.Button("Stop", elem_classes=["aux-btn"])
regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
clean = gr.Button("Clean", elem_classes=["aux-btn"])
with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
# Not referenced by any event handler in the visible code.
context_inspector = gr.Textbox(
"",
elem_id="aux-viewer-inspector",
label="",
lines=30,
max_lines=50,
)
chatbot = gr.Chatbot(elem_id='chatbot', label="LLaMA2-70B-Chat")
instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")
with gr.Accordion("Example Templates", open=False):
# template_txt receives the clicked example (hidden); template_md displays it.
template_txt = gr.Textbox(visible=False)
template_md = gr.Markdown(label="Chosen Template", visible=False, elem_classes="template-txt")
with gr.Row():
# Up to three placeholder inputs; fill_up_placeholders decides which are shown.
placeholder_txt1 = gr.Textbox(label="placeholder #1", visible=False, interactive=True)
placeholder_txt2 = gr.Textbox(label="placeholder #2", visible=False, interactive=True)
placeholder_txt3 = gr.Textbox(label="placeholder #3", visible=False, interactive=True)
# One tab of clickable examples per entry in `templates`; clicking one runs
# fill_up_placeholders on it.
for template in templates:
with gr.Tab(template['title']):
gr.Examples(
template['template'],
inputs=[template_txt],
outputs=[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3, instruction_txtbox],
run_on_click=True,
fn=fill_up_placeholders,
)
with gr.Accordion("Control Panel", open=False) as control_panel:
with gr.Column():
with gr.Column():
gr.Markdown("#### Global context")
with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=True):
global_context = gr.Textbox(
DEFAULT_GLOBAL_CTX,
lines=5,
max_lines=10,
interactive=True,
elem_id="global-context"
)
# gr.Markdown("#### Internet search")
# with gr.Row():
# internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
# serper_api_key = gr.Textbox(
# value= "" if args.serper_api_key is None else args.serper_api_key,
# placeholder="Get one by visiting serper.dev",
# label="Serper api key"
# )
gr.Markdown("#### GenConfig for **response** text generation")
with gr.Row():
# These values are passed to chat_stream and on to gen_text's `parameters`.
res_temp = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="temp", interactive=True)
res_topk = gr.Slider(20, 1000, 50, step=1, label="top_k", interactive=True)
res_rpen = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="rep_penalty", interactive=True)
res_mnts = gr.Slider(64, 8192, 512, step=1, label="new_tokens", interactive=True)
# NOTE(review): value=0 is not literally one of the choices [True, False];
# presumably False was intended -- confirm.
res_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
with gr.Column():
gr.Markdown("#### Context managements")
with gr.Row():
# Passed to chat_stream as ctx_num_lconv (the window size for build_prompts).
ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)
# Submitting the prompt runs chat_stream; its yielded pairs update the
# chatbot and local_data.
instruction_txtbox.submit(
chat_stream,
[idx, local_data, instruction_txtbox, chat_state,
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv],
[chatbot, local_data]
)
# Clicking a channel button loads that conversation, then runs the
# UPDATE_LEFT_BTNS_STATE JavaScript with the button as input.
for btn in channel_btns:
btn.click(
set_chatbot,
[btn, local_data, chat_state],
[chatbot, idx, example_block, regenerate]
).then(
None, btn, None,
_js=UPDATE_LEFT_BTNS_STATE
)
# Clicking an example copies its text into the prompt box and hides the popup.
for btn in ex_btns:
btn.click(
set_example,
[btn],
[instruction_txtbox, example_block]
)
# Changes to any placeholder box run only JavaScript (fn=None), with
# template_md as the output.
placeholder_txt1.change(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[template_md],
show_progress=False,
_js=UPDATE_PLACEHOLDERS,
fn=None
)
placeholder_txt2.change(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[template_md],
show_progress=False,
_js=UPDATE_PLACEHOLDERS,
fn=None
)
placeholder_txt3.change(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[template_md],
show_progress=False,
_js=UPDATE_PLACEHOLDERS,
fn=None
)
# Submitting any placeholder box writes the filled-in template to the prompt
# box and clears all three placeholder boxes (see get_final_template).
placeholder_txt1.submit(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
fn=get_final_template
)
placeholder_txt2.submit(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
fn=get_final_template
)
placeholder_txt3.submit(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
fn=get_final_template
)
# On page load, run the GET_LOCAL_STORAGE JavaScript (no Python function);
# its result populates the chatbot and local_data.
demo.load(
None,
inputs=None,
outputs=[chatbot, local_data],
_js=GET_LOCAL_STORAGE,
)
# Start the Gradio server with default settings.
demo.launch()