Spaces:
Running
Running
| import os | |
| import re | |
| import gradio as gr | |
| from text_generation import Client | |
| from dialogues import DialogueTemplate | |
# Model name -> inference endpoint URL. The URL comes from the API_URL
# environment variable and is None when that variable is not set.
model2endpoint = {"starchat-beta": os.environ.get("API_URL")}

# Names of the configured models, in declaration order.
model_names = [*model2endpoint]
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Build the full prompt string from the preprompt, past turns and new input.

    Args:
        inputs: The new user message.
        chatbot: Iterable of ``(user_text, model_text)`` pairs for past turns.
        preprompt: Text placed at the very start of the prompt.
        user_name: Prefix marking a user turn; added when missing.
        assistant_name: Prefix marking an assistant turn; added (after
            ``sep``) when missing.
        sep: Separator placed between turns.

    Returns:
        ``preprompt`` + every past turn + the new user turn + ``sep`` +
        the right-stripped ``assistant_name``.
    """

    def _with_prefix(text, prefix):
        # Leave text alone when it already carries the prefix.
        return text if text.startswith(prefix) else prefix + text

    assistant_prefix = sep + assistant_name
    turns = [
        _with_prefix(user_text, user_name)
        + _with_prefix(model_text, assistant_prefix).rstrip()
        + sep
        for user_text, model_text in chatbot
    ]
    prompt_tail = _with_prefix(inputs, user_name) + sep + assistant_name.rstrip()
    return preprompt + "".join(turns) + prompt_tail
def wrap_html_code(text):
    """Fence ``text`` in triple backticks when it contains a ``<...>`` span.

    Text without any ``<`` ... ``>`` pair on a single line is returned
    unchanged.
    """
    looks_like_markup = re.search(r"<.*?>", text) is not None
    return f"```{text}```" if looks_like_markup else text
def has_no_history(chatbot, history):
    """Return True only when both ``chatbot`` and ``history`` are empty/falsy."""
    return not (chatbot or history)
def generate(
    user_message,
    chatbot,
    history,
):
    """Stream the assistant's reply to ``user_message``.

    Args:
        user_message: The text the user just submitted.
        chatbot: Past turns as ``(user_text, model_text)`` pairs; used to
            build the prompt.
        history: Flat list ``[user, reply, user, reply, ...]``; mutated in
            place as the reply streams in.

    Yields:
        ``(chat, history, user_message, "")`` after every visible token,
        where ``chat`` is the list of display rows rebuilt from ``history``.
        For an empty ``user_message`` the current state is yielded once,
        unchanged, and no request is sent.
    """
    # Don't query the model (or touch history) when the input is empty.
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return chatbot, history, user_message, ""

    system_message = "Below is a conversation between a human user and a helpful AI coding assistant."
    temperature = 0.2
    top_p = 0.95
    max_new_tokens = 1024
    repetition_penalty = 1.2

    client = Client(model2endpoint["starchat-beta"])

    history.append(user_message)

    # Prompt = every past turn followed by the new user message.
    messages = []
    for user_data, model_data in chatbot:
        messages.append({"role": "user", "content": user_data})
        messages.append({"role": "assistant", "content": model_data.rstrip()})
    messages.append({"role": "user", "content": user_message})
    prompt = DialogueTemplate(system=system_message, messages=messages).get_inference_prompt()

    generate_kwargs = dict(
        # Keep the temperature away from zero.
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=1000,
        seed=42,
        stop_sequences=["<|end|>"],
    )
    stream = client.generate_stream(prompt, **generate_kwargs)

    output = ""
    chat = []
    # Track the reply slot explicitly: keying on the stream index would
    # overwrite the user message whenever the first token is special.
    reply_started = False
    for response in stream:
        if response.token.special:
            continue
        output += response.token.text
        if reply_started:
            history[-1] = output
        else:
            history.append(" " + output)
            reply_started = True
        chat = _history_to_chat(history)
        yield chat, history, user_message, ""

    if not reply_started:
        # No visible token arrived: record an empty reply so that history
        # stays in user/reply pairs for the next call.
        history.append("")
        chat = _history_to_chat(history)
        yield chat, history, user_message, ""

    return chat, history, user_message, ""


def _history_to_chat(history):
    """Pair the flat ``[user, reply, ...]`` history into display rows."""
    return [
        (wrap_html_code(history[i].strip()), wrap_html_code(history[i + 1].strip()))
        for i in range(0, len(history) - 1, 2)
    ]
# Sample prompts for the assistant.
examples = [
    "How can I write a Python function to generate the nth Fibonacci number?",
    "How do I get the current date using shell commands? Explain how it works.",
    "What's the meaning of life?",
    "Write a function in Javascript to reverse words in a given string.",
    (
        "Give the following data {'Name':['Tom', 'Brad', 'Kyle', 'Jerry'], "
        "'Age':[20, 21, 19, 18], 'Height' : [6.1, 5.9, 6.0, 6.1]}. "
        "Can you plot one graph with two subplots as columns. "
        "The first is a bar graph showing the height of each person. "
        "The second is a bargraph showing the age of each person? "
        "Draw the graph in seaborn talk mode."
    ),
    "Create a regex to extract dates from logs",
    "How to decode JSON into a typescript object",
    "Write a list into a jsonlines file and save locally",
]
def clear_chat():
    """Return a pair of fresh, independent empty lists."""
    first_empty, second_empty = [], []
    return first_empty, second_empty
def process_example(args):
    """Run ``generate`` on one example prompt and return its final state.

    Args:
        args: The example prompt text, passed to ``generate`` as the user
            message with an empty chatbot and an empty history.

    Returns:
        ``[chat, history]`` taken from the last value ``generate`` yields,
        or ``[[], []]`` if it yields nothing.
    """
    # generate() requires (user_message, chatbot, history) and yields
    # 4-tuples; only the first two items are kept here.
    chat, history = [], []
    for chat, history, *_ in generate(args, [], []):
        pass
    return [chat, history]
# Header markup passed to gr.HTML (user-facing text is in Spanish).
title = (
    '<h2 align="center">⭐ StarChat Saturdays 💬</h2>\n'
    '<h4 align="center">Asistente de IA para estudiantes de Inteligencia Artificial</h4>\n'
)

# Privacy notice markup passed to gr.HTML.
info = (
    '<h5 align="center">¡Tu privacidad es nuestra prioridad! '
    "Toda la información compartida en esta conversación se elimina "
    "automáticamente una vez que salgas del chat.</h5>"
)

# Stylesheet handed to gr.Blocks(css=...).
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
# Build the chat UI and start serving it.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost — confirm against the deployed layout.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Box():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-input",
            )
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)

    # Per-session state: flat message history and the last submitted message.
    history = gr.State([])
    last_user_message = gr.State("")

    # Pressing Enter in the textbox and clicking "Send" run the same handler
    # with the same inputs and outputs.
    for _bind in (user_message.submit, send_button.click):
        _bind(
            generate,
            inputs=[user_message, chatbot, history],
            outputs=[chatbot, history, last_user_message, user_message],
        )

    gr.HTML(info)

demo.queue(concurrency_count=16).launch()