import json
import os

import gradio as gr
import requests

# Read the Hugging Face token and each model's endpoint URL from the environment.
# Any of these is None when the corresponding variable is not set.
hf_token = os.getenv("HF_TOKEN")
llama_7b = os.getenv("API_URL_LLAMA_7")
llama_13b = os.getenv("API_URL_LLAMA_13")
zephyr_7b = os.getenv("API_URL_ZEPHYR_7")

# The "Model" dropdown hands chat() one of these *names*, not a URL, so the
# name has to be translated to its endpoint before posting.
MODEL_URLS = {
    "llama_7b": llama_7b,
    "llama_13b": llama_13b,
    "zephyr_7b": zephyr_7b,
}
DEFAULT_MODEL = "llama_13b"

headers = {
    'Content-Type': 'application/json',
}


def _resolve_model_url(model):
    """Return the endpoint URL for `model`.

    `model` may be a key of MODEL_URLS (what the dropdown sends), a URL (the
    historical default argument of chat), or None (dropdown cleared), which
    falls back to DEFAULT_MODEL. Unknown values are passed through unchanged.
    """
    if model is None:
        return MODEL_URLS[DEFAULT_MODEL]
    return MODEL_URLS.get(model, model)


def _build_prompt(message, chatbot, system_prompt):
    """Render the chat history plus the new message in the Llama 2 chat format.

    The optional system prompt is wrapped in <<SYS>> ... <</SYS>> inside the
    first [INST] block. Each completed turn is closed with </s> and the next
    turn is opened with <s>[INST], as in the Llama 2 chat template.
    """
    if system_prompt:
        prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    else:
        prompt = "[INST] "
    for interaction in chatbot or []:
        prompt += f"{interaction[0]} [/INST] {interaction[1]} </s><s>[INST] "
    return prompt + f"{message} [/INST] "


def chat(message, chatbot, model=llama_13b, system_prompt="", temperature=0.9,
         max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    """Stream the selected model's reply to `message`.

    Args:
        message: The new user message.
        chatbot: Previous turns as [user, bot] pairs, as supplied by
            gr.ChatInterface.
        model: A MODEL_URLS key (as sent by the dropdown) or an endpoint URL.
        system_prompt: Optional system prompt; ignored when empty.
        temperature: Sampling temperature, clamped to at least 1e-2.
        max_new_tokens: Generation budget, clamped to at least 1.
        top_p: Nucleus-sampling threshold; values >= 1.0 disable it.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The reply accumulated so far, or an error message for the user.

    Raises:
        gr.Error: If no endpoint URL is configured for `model`.
    """
    url = _resolve_model_url(model)
    if not url:
        raise gr.Error(f"No endpoint URL is configured for model {model!r}.")

    # Sampling needs a strictly positive temperature; clamp zero/tiny values.
    temperature = max(float(temperature), 1e-2)
    # The "Max new tokens" slider allows 0, which would request no output.
    max_new_tokens = max(int(max_new_tokens), 1)
    top_p = float(top_p)

    parameters = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
    }
    # top_p == 1.0 means "no nucleus filtering". The stream format parsed below
    # matches text-generation-inference, which rejects top_p outside the open
    # interval (0, 1), so omit it at 1.0 and keep small values positive.
    if top_p < 1.0:
        parameters["top_p"] = max(top_p, 1e-2)

    data = {
        "inputs": _build_prompt(message, chatbot, system_prompt),
        "parameters": parameters,
    }

    # `with` guarantees the streamed connection is released, even when the
    # consumer stops iterating early.
    with requests.post(url, headers=headers, data=json.dumps(data),
                       auth=("hf", hf_token), stream=True) as response:
        if not response.ok:
            # Error bodies are not prefixed with 'data:', so without this check
            # the user would only see a stream of warnings and no reply.
            yield f"Request failed (HTTP {response.status_code}): {response.text}"
            return

        partial_message = ""
        for line in response.iter_lines():
            if not line:  # filter out keep-alive new lines
                continue
            decoded_line = line.decode('utf-8')
            # Each event's payload follows a 'data:' prefix.
            if not decoded_line.startswith('data:'):
                gr.Warning(f"This line does not start with 'data:': {decoded_line}")
                continue
            json_line = decoded_line[5:]  # drop the 5-character 'data:' prefix

            try:
                json_obj = json.loads(json_line)
            except json.JSONDecodeError:
                gr.Warning(f"This line is not valid JSON: {json_line}")
                continue

            try:
                if 'token' in json_obj:
                    partial_message += json_obj['token']['text']
                    yield partial_message
                elif 'error' in json_obj:
                    yield (f"{json_obj['error']}. Please refresh and try again "
                           "with an appropriate smaller input prompt.")
                else:
                    gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")


additional_inputs = [
    gr.Dropdown(
        choices=list(MODEL_URLS),
        value=DEFAULT_MODEL,
        label="Model",
        info="Which model do you want to use?",
    ),
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

title = "Find the password 🔒"
description = "In this game prototype, your goal is to discuss with the intercom to find the correct password"

chatbot = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)

chat_interface = gr.ChatInterface(
    chat,
    title=title,
    description=description,
    textbox=gr.Textbox(),
    chatbot=chatbot,
    additional_inputs=additional_inputs,
)

# Gradio Demo
with gr.Blocks() as demo:
    chat_interface.render()

demo.launch(debug=True)