import json
import gradio as gr
import os
import requests

# Read the access token and the model API URLs from the environment
hf_token = os.getenv("HF_TOKEN")
llama_7b = os.getenv("API_URL_LLAMA_7")
llama_13b = os.getenv("API_URL_LLAMA_13")
zephyr_7b = os.getenv("API_URL_ZEPHYR_7")

headers = {
    'Content-Type': 'application/json',
}
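
# NOTE (assumption): HF_TOKEN is expected to be set as a Space secret; it is
# used below as the basic-auth password when calling the inference endpoints.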
""" | |
Chat Function | |
""" | |
def chat(message,
         chatbot,
         model=llama_13b,
         system_prompt="",
         temperature=0.9,
         max_new_tokens=256,
         top_p=0.6,
         repetition_penalty=1.0,
         ):
    # Build the prompt in the Llama 2 chat format, starting with the
    # optional system prompt
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    temperature = float(temperature)
    # Clamp the temperature so it is never below 1e-2
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Append the conversation history, one [INST] ... [/INST] block per turn
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "

    # Finally, append the new user message
    input_prompt = input_prompt + str(message) + " [/INST] "

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
print("MODEL" + model) | |
if model == "zephyr_7b": | |
model = zephyr_7b | |
elif model == "llama_7b": | |
model = llama_7b | |
elif model == "llama_13b": | |
model == llama_13b | |
response = requests.post(model, headers=headers, data=json.dumps(data), auth=("hf", hf_token), stream=True) | |
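
    # NOTE (assumption): the endpoints are expected to be text-generation-
    # inference-style servers that stream Server-Sent Events; each payload
    # line looks roughly like:
    #   data:{"token": {"id": 1234, "text": " hello", ...}}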
partial_message = "" | |
for line in response.iter_lines(): | |
if line: # filter out keep-alive new lines | |
# Decode from bytes to string | |
decoded_line = line.decode('utf-8') | |
# Remove 'data:' prefix | |
if decoded_line.startswith('data:'): | |
json_line = decoded_line[5:] # Exclude the first 5 characters ('data:') | |
else: | |
gr.Warning(f"This line does not start with 'data:': {decoded_line}") | |
continue | |
# Load as JSON | |
try: | |
json_obj = json.loads(json_line) | |
if 'token' in json_obj: | |
partial_message = partial_message + json_obj['token']['text'] | |
yield partial_message | |
elif 'error' in json_obj: | |
yield json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.' | |
else: | |
gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}") | |
except json.JSONDecodeError: | |
gr.Warning(f"This line is not valid JSON: {json_line}") | |
continue | |
except KeyError as e: | |
gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}") | |
continue | |
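
# Example (hypothetical, for local testing only): chat() is a generator, so the
# growing response can be consumed incrementally:
#
#   for partial in chat("Is the password 1234?", [], model="llama_13b"):
#       print(partial)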

additional_inputs = [
    gr.Dropdown(choices=["llama_7b", "llama_13b", "zephyr_7b"], value="llama_13b", label="Model", info="Which model do you want to use?"),
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
title = "Find the password π" | |
description = "In this game prototype, your goal is to discuss with the intercom to find the correct password" | |
chatbot = gr.Chatbot(avatar_images=('user.png', 'bot2.png'),bubble_full_width = False) | |
chat_interface = gr.ChatInterface(chat, | |
title=title, | |
description=description, | |
textbox=gr.Textbox(), | |
chatbot=chatbot, | |
additional_inputs=additional_inputs) | |

# Gradio Demo
with gr.Blocks() as demo:
    chat_interface.render()

demo.launch(debug=True)
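
# NOTE: debug=True keeps launch() blocking and prints server errors to the
# console, which is handy while iterating on the prompt format in a Space.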