# FindThePassword / app.py
# Author: Thomas Simonini — commit b9f4649 ("Update demo")
# (Hugging Face file-viewer header: raw · history · blame · "No virus" scan · 4.55 kB)
import json
import gradio as gr
import os
import requests
# Hugging Face token, sent as the basic-auth password on every inference request.
hf_token = os.getenv("HF_TOKEN")
# Inference endpoint URLs for the three model sizes (7B / 13B / 70B judging by the
# variable names); each is None when its environment variable is unset.
api_url_7b = os.getenv("API_URL_LLAMA_7")
api_url_13b = os.getenv("API_URL_LLAMA_13")
api_url_70b = os.getenv("API_URL_LLAMA_70")
# Request headers: the payload is sent as a JSON-encoded body.
headers = {
    'Content-Type': 'application/json',
}
# Text shown at the top of the chat UI.
title = "Find the password 🔒"
description = "In this game prototype, your goal is to discuss with the intercom to find the correct password"
def _resolve_model_url(model):
    """Translate the "Model" dropdown value into an inference endpoint URL.

    The dropdown offers the variable names "api_url_7b", "api_url_13b" and
    "api_url_70b"; those are mapped to the matching module-level URLs. Any other
    value is assumed to already be a URL and is returned unchanged. A missing or
    empty selection (or an endpoint whose environment variable is unset) falls
    back to the 70B endpoint.
    """
    model_urls = {
        "api_url_7b": api_url_7b,
        "api_url_13b": api_url_13b,
        "api_url_70b": api_url_70b,
    }
    return model_urls.get(model, model) or api_url_70b


def _build_prompt(message, chatbot, system_prompt=""):
    """Assemble a Llama-2-chat formatted prompt string.

    Layout: an optional ``<<SYS>>`` block, then one
    ``user [/INST] reply </s><s>[INST]`` segment per history turn, then the new
    message followed by ``[/INST]`` so the model writes the next reply.

    Args:
        message: The new user message.
        chatbot: Chat history; each entry is indexed as (user_text, bot_text).
        system_prompt: Optional system prompt; skipped when empty or None.
    """
    if system_prompt:
        header = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        header = "<s>[INST] "
    parts = [header]
    for interaction in chatbot or ():
        parts.append(str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] ")
    parts.append(str(message) + " [/INST] ")
    return "".join(parts)


def predict(message,
            chatbot,
            system_prompt = "",
            temperature = 0.9,
            max_new_tokens = 256,
            top_p = 0.6,
            repetition_penalty = 1.0,
            model=api_url_70b):
    """Stream a reply to *message* from the selected text-generation endpoint.

    Gradio's ChatInterface calls this with the new message, the chat history and
    then the ``additional_inputs`` values, positionally, in parameter order.

    Args:
        message: The new user message.
        chatbot: Chat history as (user_text, bot_text) pairs.
        system_prompt: Optional system prompt placed in the ``<<SYS>>`` block.
        temperature: Sampling temperature; clamped to at least 1e-2.
        max_new_tokens: Maximum number of tokens to generate (sent as an int).
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        model: Dropdown value ("api_url_7b" / "api_url_13b" / "api_url_70b") or
            an endpoint URL.

    Yields:
        The reply accumulated so far, after each streamed token. On an HTTP error
        or a server-reported ``error`` event, yields an error message instead.
    """
    input_prompt = _build_prompt(message, chatbot, system_prompt)

    # Never send a temperature below 1e-2 (the UI slider goes down to 0.0).
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    data = {
        "inputs": input_prompt,
        "parameters": {
            # Sliders may deliver floats; a token count must be an integer.
            "max_new_tokens": int(max_new_tokens),
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    # `with` closes the streamed response (releasing the connection) even if the
    # caller stops consuming this generator before the stream ends.
    with requests.post(
        _resolve_model_url(model),
        headers=headers,
        data=json.dumps(data),
        auth=("hf", hf_token),  # basic auth needs a (user, password) tuple
        stream=True,
    ) as response:
        if not response.ok:
            # Show the failure in the chat instead of warning about every line of
            # a non-stream error body.
            yield f"Request failed with HTTP {response.status_code}: {response.text}"
            return

        partial_message = ""
        for line in response.iter_lines():
            if not line:  # filter out keep-alive new lines
                continue
            decoded_line = line.decode('utf-8')
            if not decoded_line.startswith('data:'):
                gr.Warning(f"This line does not start with 'data:': {decoded_line}")
                continue
            json_line = decoded_line[5:]  # drop the 'data:' prefix

            try:
                json_obj = json.loads(json_line)
            except json.JSONDecodeError:
                gr.Warning(f"This line is not valid JSON: {json_line}")
                continue

            if 'token' in json_obj:
                try:
                    token_text = json_obj['token']['text']
                except KeyError as e:
                    gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
                    continue
                partial_message = partial_message + token_text
                yield partial_message
            elif 'error' in json_obj:
                yield str(json_obj['error']) + '. Please refresh and try again with an appropriate smaller input prompt.'
            else:
                gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
# Extra controls shown under the chat box. Gradio passes their values to the chat
# function positionally after (message, history), so this list must follow the
# order of predict's parameters: system_prompt, temperature, max_new_tokens,
# top_p, repetition_penalty, model.
additional_inputs=[
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        # At least one step, so a request never asks for 0 new tokens.
        minimum=64,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
    # Must stay last: it feeds predict's final `model` parameter. The choices are
    # endpoint variable names, not URLs; default to the largest model.
    gr.Dropdown(
        ["api_url_7b", "api_url_13b", "api_url_70b"],
        value="api_url_70b",
        label="Model",
        info="Which model to use?",
    ),
]
# Chat window with custom user/bot avatar images.
chatbot = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)

# Streaming chat UI around `predict`. The model is chosen via the "Model" control
# in `additional_inputs`. No examples are passed because none are defined in this
# file; re-add `examples=[...]` (and optionally `cache_examples`) if wanted.
chat_interface_stream = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    textbox=gr.Textbox(),
    chatbot=chatbot,
    additional_inputs=additional_inputs,
)

# Gradio Demo
with gr.Blocks() as demo:
    chat_interface_stream.render()

demo.launch(debug=True)