import os
import re

import gradio as gr
from text_generation import Client

from dialogues import DialogueTemplate

# Map each model name to its inference endpoint (read from the environment).
model2endpoint = {
    "starchat-beta": os.environ.get("API_URL", None),
}
model_names = list(model2endpoint.keys())


def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Concatenate the preprompt, the past turns, and the new input into a single prompt string."""
    past = []
    for data in chatbot:
        user_data, model_data = data

        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    return total_inputs


def wrap_html_code(text):
    """Wrap text that contains HTML-like tags in a code fence so the chatbot renders it verbatim."""
    pattern = r"<.*?>"
    matches = re.findall(pattern, text)
    if len(matches) > 0:
        return f"```{text}```"
    else:
        return text


def has_no_history(chatbot, history):
    return not chatbot and not history


def generate(
    user_message,
    chatbot,
    history,
):
    """Stream a completion for `user_message` and yield updated chat state for the Gradio UI."""
    system_message = "Below is a conversation between a human user and a helpful AI coding assistant."
    temperature = 0.2
    top_k = 50
    top_p = 0.95
    max_new_tokens = 1024
    repetition_penalty = 1.2

    client = Client(model2endpoint["starchat-beta"])

    # Don't query the model when the input is empty.
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return

    history.append(user_message)

    past_messages = []
    for data in chatbot:
        user_data, model_data = data
        past_messages.extend(
            [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
        )

    # Build the inference prompt from the system message, any past turns, and the new user message.
    dialogue_template = DialogueTemplate(
        system=system_message, messages=past_messages + [{"role": "user", "content": user_message}]
    )
    prompt = dialogue_template.get_inference_prompt()

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=1000,
        seed=42,
        stop_sequences=["<|end|>"],
    )

    # Stream tokens from the endpoint and update the chat history incrementally.
    stream = client.generate_stream(prompt, **generate_kwargs)

    output = ""
    for idx, response in enumerate(stream):
        if response.token.special:
            continue
        output += response.token.text
        if idx == 0:
            history.append(" " + output)
        else:
            history[-1] = output

        chat = [
            (wrap_html_code(history[i].strip()), wrap_html_code(history[i + 1].strip()))
            for i in range(0, len(history) - 1, 2)
        ]

        yield chat, history, user_message, ""

    return chat, history, user_message, ""
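# For reference, the inference prompt built by DialogueTemplate for
# starchat-beta is expected to follow the model's chat format, roughly
# (the exact string depends on the `dialogues` module bundled with this app):
#
#   <|system|>
#   Below is a conversation between a human user and a helpful AI coding assistant.<|end|>
#   <|user|>
#   How do I get the current date using shell commands?<|end|>
#   <|assistant|>
#
# Generation is stopped at the "<|end|>" token via `stop_sequences` above.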
examples = [
    "How can I write a Python function to generate the nth Fibonacci number?",
    "How do I get the current date using shell commands? Explain how it works.",
    "What's the meaning of life?",
    "Write a function in Javascript to reverse words in a given string.",
    "Given the following data {'Name': ['Tom', 'Brad', 'Kyle', 'Jerry'], 'Age': [20, 21, 19, 18], 'Height': [6.1, 5.9, 6.0, 6.1]}, can you plot one graph with two subplots as columns? The first is a bar graph showing the height of each person, and the second is a bar graph showing the age of each person. Draw the graph in seaborn talk mode.",
    "Create a regex to extract dates from logs",
    "How to decode JSON into a typescript object",
    "Write a list into a jsonlines file and save locally",
]


def clear_chat():
    return [], []


def process_example(args):
    # Run the generator to completion against an empty chat and return the final state.
    for chat, history, _, _ in generate(args, [], []):
        pass
    return [chat, history]

title = """<h1 align="center">⭐ StarChat Saturdays 💬</h1>
<h3 align="center">An AI assistant for Artificial Intelligence students</h3>
"""

info = """Your privacy is our priority! All information shared in this conversation is automatically deleted once you leave the chat.
"""
""" custom_css = """ #banner-image { display: block; margin-left: auto; margin-right: auto; } #chat-message { font-size: 14px; min-height: 300px; } """ with gr.Blocks(analytics_enabled=False, css=custom_css) as demo: gr.HTML(title) with gr.Row(): with gr.Box(): output = gr.Markdown() chatbot = gr.Chatbot(elem_id="chat-message", label="Chat") with gr.Row(): with gr.Column(scale=3): user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input") with gr.Row(): send_button = gr.Button("Send", elem_id="send-btn", visible=True) history = gr.State([]) last_user_message = gr.State("") user_message.submit( generate, inputs=[ user_message, chatbot, history, ], outputs=[chatbot, history, last_user_message, user_message], ) send_button.click( generate, inputs=[ user_message, chatbot, history, ], outputs=[chatbot, history, last_user_message, user_message], ) gr.HTML(info) demo.queue(concurrency_count=16).launch()