import os import gradio as gr from gradio import interface, blocks import requests from text_generation import Client, InferenceAPIClient openchat_preprompt = ( "\n: Generate a button that says hello\n:\n" ) preprompt = "[REQUIRMENTS]:\n Only output in html syntax.\n Do not output a html file! \n Do not use the tag! \n DO NOT USE
tag, DO not output explanation. Do not use Natural Language, Only answer in html syntax, \n only output the html for the elements in my question, DO NOT USE HELLO WORLD!!!," prepromptTags = [ '
', '

', '

', '

', '

', '

', '

', '
', '', '', '', ] def get_client(model: str): if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B": return Client(os.getenv("OPENCHAT_API_URL")) return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None)) def get_usernames(model: str): """ Returns: (str, str, str, str): pre-prompt, username, bot name, separator """ if model == "OpenAssistant/oasst-sft-1-pythia-12b": return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>" if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B": return openchat_preprompt, ": ", ": ", "\n" return "", "User: ", "Assistant: ", "\n" def predict( model: str, inputs: str, typical_p: float, top_p: float, temperature: float, top_k: int, repetition_penalty: float, watermark: bool, chatbot, history, ): client = get_client(model) preprompt, user_name, assistant_name, sep = get_usernames(model) history.append(inputs) past = [] for data in chatbot: user_data, model_data = data if not user_data.startswith(user_name): user_data = user_name + user_data if not model_data.startswith(sep + assistant_name): model_data = sep + assistant_name + model_data past.append(user_data + model_data.rstrip() + sep) if not inputs.startswith(user_name): inputs = user_name + inputs total_inputs = preprompt + \ "".join(prepromptTags) + "".join(past) + \ inputs + sep + assistant_name.rstrip() partial_words = "" if model == "OpenAssistant/oasst-sft-1-pythia-12b": iterator = client.generate_stream( total_inputs, typical_p=typical_p, truncate=1000, watermark=watermark, max_new_tokens=500, ) else: iterator = client.generate_stream( total_inputs, top_p=top_p if top_p < 1.0 else None, top_k=top_k, truncate=1000, repetition_penalty=repetition_penalty, watermark=watermark, temperature=temperature, max_new_tokens=500, stop_sequences=[user_name.rstrip(), assistant_name.rstrip()], ) for i, response in enumerate(iterator): if response.token.special: continue partial_words = partial_words + response.token.text if partial_words.endswith(user_name.rstrip()): partial_words = partial_words.rstrip(user_name.rstrip()) if partial_words.endswith(assistant_name.rstrip()): partial_words = partial_words.rstrip(assistant_name.rstrip()) if i == 0: history.append(" " + partial_words) elif response.token.text not in user_name: history[-1] = partial_words chat = [ (history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2) ] yield chat, history def reset_textbox(): return gr.update(value="") def radio_on_change( value: str, typical_p, top_p, top_k, temperature, repetition_penalty, watermark, ): if value == "OpenAssistant/oasst-sft-1-pythia-12b": typical_p = typical_p.update(value=0.2, visible=True) top_p = top_p.update(visible=False) top_k = top_k.update(visible=False) temperature = temperature.update(visible=False) repetition_penalty = repetition_penalty.update(visible=False) watermark = watermark.update(False) elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B": typical_p = typical_p.update(visible=False) top_p = top_p.update(value=0.25, visible=True) top_k = top_k.update(value=50, visible=True) temperature = temperature.update(value=0.6, visible=True) repetition_penalty = repetition_penalty.update( value=1.01, visible=True) watermark = watermark.update(False) else: typical_p = typical_p.update(visible=False) top_p = top_p.update(value=0.95, visible=True) top_k = top_k.update(value=4, visible=True) temperature = temperature.update(value=0.5, visible=True) repetition_penalty = repetition_penalty.update( value=1.03, visible=True) watermark = watermark.update(True) return ( typical_p, top_p, top_k, temperature, repetition_penalty, watermark, ) title = """

WAB-Assist

""" with gr.Blocks( css=""" #col_container {margin-left: auto; margin-right: auto;} #chatbot {height: 420px; overflow: auto; box-shadow: 0 0 10px rgba(0,0,0,0.2)} #userInput { box-shadow: 0 0 10px rgba(0,0,0,0.2);padding:0px;} #userInput span{display:none} #submit, #api {max-width: max-content;background: #313170;color: white;} """ ) as view: gr.HTML(title) gr.Markdown(visible=True) with gr.Column(elem_id="col_container"): model = gr.Radio( value="OpenAssistant/oasst-sft-1-pythia-12b", choices=[ "OpenAssistant/oasst-sft-1-pythia-12b", # "togethercomputer/GPT-NeoXT-Chat-Base-20B", # "google/flan-t5-xxl", # "google/flan-ul2", # "bigscience/bloom", # "bigscience/bloomz", # "EleutherAI/gpt-neox-20b", ], label="Model", interactive=False, visible=False ) chatbot = gr.Chatbot(elem_id="chatbot") with gr.Row(elem_id="row"): inputs = gr.Textbox( placeholder="hey!", label="", elem_id="userInput" ) buttonSend = gr.Button(value="send", elem_id="submit") buttonAPI = gr.Button(value="api", elem_id="api") state = gr.State([]) with gr.Accordion("Parameters", open=False, visible=False): typical_p = gr.Slider( minimum=-0, maximum=1.0, value=0.55, step=0.05, interactive=True, label="Typical P mass", ) top_p = gr.Slider( minimum=-0, maximum=1.0, value=0.55, step=0.05, interactive=True, label="Top-p (nucleus sampling)", visible=True, ) temperature = gr.Slider( minimum=-0, maximum=5.0, value=3, step=0.1, interactive=True, label="Temperature", visible=True, ) top_k = gr.Slider( minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k", visible=True, ) repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=2, step=0.01, interactive=True, label="Repetition Penalty", visible=True, ) watermark = gr.Checkbox(value=False, label="Text watermarking") model.change( lambda value: radio_on_change( value, typical_p, top_p, top_k, temperature, repetition_penalty, watermark, ), inputs=model, outputs=[ typical_p, top_p, top_k, temperature, repetition_penalty, watermark, ], ) inputs.submit( predict, [ model, inputs, typical_p, top_p, temperature, top_k, repetition_penalty, watermark, chatbot, state, ], [chatbot, state], ) buttonSend.click( predict, [ model, inputs, typical_p, top_p, temperature, top_k, repetition_penalty, watermark, chatbot, state, ], [chatbot, state], ) buttonSend.click(reset_textbox, [], [inputs]) inputs.submit(reset_textbox, [], [inputs]) view.launch()