import os

import gradio as gr
from text_generation import Client

# HF-hosted endpoint for testing purposes (requires an HF API token)
API_TOKEN = os.environ.get("API_TOKEN", None)
CURRENT_CLIENT = Client(
    "https://afrts4trc759c6eq.us-east-1.aws.endpoints.huggingface.cloud/generate_stream",
    timeout=120,
    headers={
        "Accept": "application/json",
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json",
    },
)

# Prompt-building defaults, all overridable through environment variables.
DEFAULT_HEADER = os.environ.get("HEADER", "")
DEFAULT_USER_NAME = os.environ.get("USER_NAME", "user")
DEFAULT_ASSISTANT_NAME = os.environ.get("ASSISTANT_NAME", "assistant")
DEFAULT_SEPARATOR = os.environ.get("SEPARATOR", "<|im_end|>")

# ChatML-style single-turn template; `response` is left empty for the turn
# the model is asked to complete.
PROMPT_TEMPLATE = "<|im_start|>{user_name}\n{query}{separator}\n<|im_start|>{assistant_name}\n{response}"

repo = None  # NOTE(review): unused in this file; kept so external importers don't break


def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Concatenate past chat turns and the new input into one prompt string.

    NOTE(review): not called by generate() below; kept for backward
    compatibility with any external callers.

    Args:
        inputs: the new user message.
        chatbot: list of (user_text, model_text) tuples of past turns.
        preprompt: text prepended before the conversation.
        user_name / assistant_name: speaker tags to prefix turns with.
        sep: separator token inserted between speaker changes.

    Returns:
        The fully assembled prompt string.
    """
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()
    return total_inputs


def has_no_history(chatbot, history):
    """Return True when both the chatbot display and the flat history are empty."""
    return not chatbot and not history


def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    max_new_tokens,
    repetition_penalty,
    header,
    user_name,
    assistant_name,
    separator,
):
    """Stream a model completion for `user_message`.

    Builds a ChatML-style prompt from `chatbot` history plus the new message,
    streams tokens from the inference endpoint, and yields the updated
    (chat display, history state, last user message, cleared textbox) tuple
    after every token.

    Yields:
        (chat, history, user_message, ""): chat is a list of
        (user, assistant) string tuples for gr.Chatbot; the final "" clears
        the input textbox.
    """
    # Don't return meaningless message when the input is empty.
    if not user_message:
        print("Empty input")
        # BUG FIX: the original fell through here and queried the model with
        # an empty message; yield the unchanged state and stop instead.
        yield chatbot, history, user_message, ""
        return

    history.append(user_message)

    # Flatten prior turns into role/content dicts for prompt assembly.
    past_messages = []
    for user_data, model_data in chatbot:
        past_messages.extend(
            [
                {"role": "user", "content": user_data},
                {"role": "assistant", "content": model_data.rstrip()},
            ]
        )
    print(past_messages)

    if len(past_messages) < 1:
        # First turn: header followed by the single new exchange.
        prompt = header + PROMPT_TEMPLATE.format(
            user_name=user_name,
            query=user_message,
            assistant_name=assistant_name,
            response="",
            separator=separator,
        )
    else:
        # Replay completed exchanges (pairs of user/assistant messages),
        # then append the new, yet-unanswered message.
        prompt = header
        for i in range(0, len(past_messages), 2):
            intermediate_prompt = PROMPT_TEMPLATE.format(
                user_name=user_name,
                query=past_messages[i]["content"],
                assistant_name=assistant_name,
                response=past_messages[i + 1]["content"],
                separator=separator,
            )
            prompt = prompt + intermediate_prompt + separator + "\n"
        prompt = prompt + PROMPT_TEMPLATE.format(
            user_name=user_name,
            query=user_message,
            assistant_name=assistant_name,
            response="",
            separator=separator,
        )

    # Clamp temperature away from zero so sampling stays valid.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=40,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=1024,
        stop_sequences=[DEFAULT_SEPARATOR],
    )

    stream = CURRENT_CLIENT.generate_stream(
        prompt,
        **generate_kwargs,
    )

    output = ""
    # BUG FIX: pre-compute `chat` so an empty stream cannot raise NameError
    # at the final return below.
    chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
    # BUG FIX: the original used `idx == 0` to decide when to append the
    # assistant slot to history; if the first streamed token was a special
    # token (skipped via `continue`), the append never happened and the next
    # token overwrote the user's message. Track it with a flag instead.
    appended = False
    for response in stream:
        if response.token.special:
            continue
        output += response.token.text
        if not appended:
            history.append(" " + output)
            appended = True
        else:
            history[-1] = output

        chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
        yield chat, history, user_message, ""
    return chat, history, user_message, ""


def clear_chat():
    """Reset both the chatbot display and the flat history state."""
    return [], []


title = """

CroissantLLMChat Playground 🥐

"""

custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""

with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                Demo platform for 🥐 CroissantLLMChat. Model is of small size and can hallucinate and generate incorrect or even toxic content.
                """
            )

            with gr.Row():
                with gr.Group():
                    output = gr.Markdown()
                    chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

            with gr.Row():
                with gr.Column(scale=3):
                    user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
                    with gr.Row():
                        send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                        clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)

                    with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.5,
                            minimum=0.1,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=0.9,
                            minimum=0.0,
                            maximum=1,
                            step=0.05,
                            interactive=True,
                            info="Higher values sample more low-probability tokens",
                        )
                        max_new_tokens = gr.Slider(
                            label="Max new tokens",
                            value=512,
                            minimum=0,
                            maximum=1024,
                            step=4,
                            interactive=True,
                            info="The maximum numbers of new tokens",
                        )
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.2,
                            minimum=0.0,
                            maximum=10,
                            step=0.1,
                            interactive=True,
                            info="The parameter for repetition penalty. 1.0 means no penalty.",
                        )

                    with gr.Accordion(label="Prompt", open=False, elem_id="prompt-accordion"):
                        header = gr.Textbox(
                            label="Header instructions",
                            value=DEFAULT_HEADER,
                            interactive=True,
                            info="Instructions given to the assistant at the beginning of the prompt",
                        )
                        user_name = gr.Textbox(
                            label="User name",
                            value=DEFAULT_USER_NAME,
                            interactive=True,
                            info="Name to be given to the user in the prompt",
                        )
                        assistant_name = gr.Textbox(
                            label="Assistant name",
                            value=DEFAULT_ASSISTANT_NAME,
                            interactive=True,
                            info="Name to be given to the assistant in the prompt",
                        )
                        separator = gr.Textbox(
                            label="Separator",
                            value=DEFAULT_SEPARATOR,
                            interactive=True,
                            info="Character to be used when the speaker changes in the prompt",
                        )

    # Session state: flat alternating [user, assistant, ...] message list and
    # the last message the user submitted.
    history = gr.State([])
    last_user_message = gr.State("")

    user_message.submit(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            max_new_tokens,
            repetition_penalty,
            header,
            user_name,
            assistant_name,
            separator,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    send_button.click(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            max_new_tokens,
            repetition_penalty,
            header,
            user_name,
            assistant_name,
            separator,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    clear_chat_button.click(clear_chat, outputs=[chatbot, history])

demo.queue().launch()