# NOTE: removed non-code residue from the web scrape of this Space
# (page status text, file size, git-blame commit hashes, line-number gutter).
import json
import gradio as gr
import os
import requests
# --- Configuration ---------------------------------------------------------
# Auth token and per-model inference endpoint URLs are injected through the
# environment (unset variables yield None).
hf_token = os.getenv("HF_TOKEN")
llama_7b = os.getenv("API_URL_LLAMA_7")
llama_13b = os.getenv("API_URL_LLAMA_13")
zephyr_7b = os.getenv("API_URL_ZEPHYR_7")

# Every request posts a JSON payload.
headers = {"Content-Type": "application/json"}

# --- Chat function ---------------------------------------------------------
def chat(message,
         chatbot,
         model=llama_13b,
         system_prompt="",
         temperature=0.9,
         max_new_tokens=256,
         top_p=0.6,
         repetition_penalty=1.0,
         ):
    """Send the conversation to a hosted LLM endpoint and return the reply.

    Builds a Llama-2-style prompt (``[INST]`` / ``<<SYS>>`` tags) from the
    system prompt, the chat history and the new message, POSTs it to the
    selected streaming endpoint, and accumulates the streamed tokens into
    the final answer.

    Args:
        message: Latest user message.
        chatbot: Conversation history as (user, assistant) pairs.
        model: Model key ("llama_7b", "llama_13b", "zephyr_7b") or a raw
            endpoint URL (default: the llama-13b URL from the environment).
        system_prompt: Optional system instruction; "" disables it.
        temperature: Sampling temperature, clamped to >= 1e-2.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The assistant's full reply, or the endpoint's error message.
    """
    # Build the Llama-2 chat prompt.
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = f"<s>[INST] "
    # The endpoint rejects temperatures below 1e-2 — clamp.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
    input_prompt = input_prompt + str(message) + " [/INST] "
    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
    # BUG FIX: str(model) avoids a TypeError when model is None (e.g. the
    # dropdown had no selection or the env variable is unset).
    print("MODEL " + str(model))
    # Resolve a model key to its endpoint URL; a raw URL passes through.
    if model == "zephyr_7b":
        model = zephyr_7b
    elif model == "llama_7b":
        model = llama_7b
    elif model == "llama_13b":
        model = llama_13b
    partial_message = ""
    # Stream the response; the `with` block guarantees the connection is
    # released even if parsing raises.
    with requests.post(model, headers=headers, data=json.dumps(data),
                       auth=("hf", hf_token), stream=True) as response:
        for line in response.iter_lines():
            if not line:  # filter out keep-alive new lines
                continue
            # Decode from bytes to string
            decoded_line = line.decode('utf-8')
            # Remove the 'data:' SSE prefix
            if decoded_line.startswith('data:'):
                json_line = decoded_line[5:]  # exclude the first 5 characters ('data:')
            else:
                gr.Warning(f"This line does not start with 'data:': {decoded_line}")
                continue
            # Load as JSON
            try:
                json_obj = json.loads(json_line)
                if 'token' in json_obj:
                    # BUG FIX: the original returned inside this loop after the
                    # FIRST token, truncating the reply — accumulate instead and
                    # return once the stream is exhausted.
                    partial_message = partial_message + json_obj['token']['text']
                elif 'error' in json_obj:
                    return json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.'
                else:
                    gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
            except json.JSONDecodeError:
                gr.Warning(f"This line is not valid JSON: {json_line}")
                continue
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
                continue
    return partial_message
# Extra controls rendered under the chat box. Gradio passes their values
# positionally to chat() after (message, history), so the order here must
# match the signature: model, system_prompt, temperature, max_new_tokens,
# top_p, repetition_penalty.
additional_inputs = [
    # BUG FIX: without an initial `value` the dropdown submits None, which
    # crashed chat() ("MODEL" + None). Default to "llama_13b" to match the
    # function's own default.
    gr.Dropdown(
        choices=["llama_7b", "llama_13b", "zephyr_7b"],
        value="llama_13b",
        label="Model",
        info="Which model do you want to use?",
    ),
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    # NOTE(review): 1.2 here differs from chat()'s repetition_penalty=1.0
    # default; the slider value is what actually reaches the endpoint —
    # confirm 1.2 is intended.
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# --- UI assembly -----------------------------------------------------------
# Copy shown at the top of the page.
title = "Find the password 🔒"
description = "In this game prototype, your goal is to discuss with the intercom to find the correct password"

# Chat widget with custom avatars; bubbles sized to their content.
chatbot = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)

# Wire the chat() callback and its extra controls into a ChatInterface.
chat_interface = gr.ChatInterface(
    chat,
    title=title,
    description=description,
    textbox=gr.Textbox(),
    chatbot=chatbot,
    additional_inputs=additional_inputs,
)

# Gradio Demo
with gr.Blocks() as demo:
    chat_interface.render()

demo.launch(debug=True)