# FindThePassword / app.py
import json
import gradio as gr
import os
import requests
# Get the Hugging Face token and the model endpoint URLs from the environment
hf_token = os.getenv("HF_TOKEN")
llama_7b = os.getenv("API_URL_LLAMA_7")
llama_13b = os.getenv("API_URL_LLAMA_13")
zephyr_7b = os.getenv("API_URL_ZEPHYR_7")

headers = {
    'Content-Type': 'application/json',
}
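# For local testing, these secrets would need to be set before launching,
# e.g. (hypothetical values):
#   export HF_TOKEN=hf_xxx
#   export API_URL_LLAMA_7=https://<your-llama-7b-endpoint>
#   export API_URL_LLAMA_13=https://<your-llama-13b-endpoint>
#   export API_URL_ZEPHYR_7=https://<your-zephyr-7b-endpoint>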
"""
Chat Function
"""
def chat(message,
         chatbot,
         model=llama_13b,
         system_prompt="",
         temperature=0.9,
         max_new_tokens=256,
         top_p=0.6,
         repetition_penalty=1.0
         ):
    # Build the Llama-2 prompt: wrap the optional system prompt in <<SYS>> tags
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    # Clamp the temperature so it is never below 1e-2
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Append the previous turns of the conversation in the [INST] ... [/INST] format
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "

    # Finally, append the new user message
    input_prompt = input_prompt + str(message) + " [/INST] "
    # Payload for the text-generation endpoint
    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    # Map the model name chosen in the UI to its endpoint URL
    print("MODEL: " + str(model))
    if model == "zephyr_7b":
        model = zephyr_7b
    elif model == "llama_7b":
        model = llama_7b
    elif model == "llama_13b":
        model = llama_13b

    # Send the request and stream the response back line by line
    response = requests.post(model, headers=headers, data=json.dumps(data), auth=("hf", hf_token), stream=True)
    # Parse the server-sent events stream and accumulate the generated text
    partial_message = ""
    for line in response.iter_lines():
        if line:  # filter out keep-alive new lines
            # Decode from bytes to string
            decoded_line = line.decode('utf-8')

            # Remove the 'data:' prefix
            if decoded_line.startswith('data:'):
                json_line = decoded_line[5:]  # Exclude the first 5 characters ('data:')
            else:
                gr.Warning(f"This line does not start with 'data:': {decoded_line}")
                continue

            # Load as JSON
            try:
                json_obj = json.loads(json_line)
                if 'token' in json_obj:
                    partial_message = partial_message + json_obj['token']['text']
                    yield partial_message
                elif 'error' in json_obj:
                    yield json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.'
                    return
                else:
                    gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
            except json.JSONDecodeError:
                gr.Warning(f"This line is not valid JSON: {json_line}")
                continue
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
                continue
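# Example (sketch, not executed here): chat() can be called directly for a quick
# test outside the UI. The history is a list of (user, assistant) tuples, matching
# what gr.ChatInterface passes in; the endpoint secrets above must be set.
#
#   for partial in chat("Hello, what is the password?", [], model="llama_13b"):
#       print(partial)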
# Additional controls shown under the chat box
additional_inputs = [
    gr.Dropdown(choices=["llama_7b", "llama_13b", "zephyr_7b"], value="llama_13b", label="Model", info="Which model do you want to use?"),
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
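# Note: gr.ChatInterface passes these additional inputs to chat() positionally,
# after the message and history, so their order must match the signature:
# model, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty.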
title = "Find the password πŸ”’"
description = "In this game prototype, your goal is to discuss with the intercom to find the correct password"
chatbot = gr.Chatbot(avatar_images=('user.png', 'bot2.png'),bubble_full_width = False)
chat_interface = gr.ChatInterface(chat,
                                  title=title,
                                  description=description,
                                  textbox=gr.Textbox(),
                                  chatbot=chatbot,
                                  additional_inputs=additional_inputs)
# Gradio Demo
with gr.Blocks() as demo:
    chat_interface.render()

demo.launch(debug=True)