import json
import os
import requests
import gradio as gr
CONTROLLER_URL = os.environ["CTRL_URL"]
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", 1.0))
DEFAULT_TOPP = float(os.environ.get("DEFAULT_TOPP", 1.0))
DEFAULT_REP_PENALTY = float(os.environ.get("DEFAULT_REP_PENALTY", 1.0))
DEFAULT_MAX_NEW = int(os.environ.get("DEFAULT_MAX_NEW", 128))
WORKER_ERROR_MSG = "MODEL WORKER NOT FOUND. PLEASE REFRESH THIS PAGE."
WORKER_API_TIMEOUT = float(os.environ.get("WORKER_TIMEOUT", 100))
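# Example environment configuration (illustrative values, not taken from this repo;
# CTRL_URL must point at a running FastChat-style controller):
#   CTRL_URL=http://localhost:21001
#   DEFAULT_TEMPERATURE=0.2
#   DEFAULT_TOPP=0.95
#   DEFAULT_MAX_NEW=128
#   WORKER_TIMEOUT=100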
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
FIM_INDICATOR = "<FILL_HERE>"
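# Fill-in-the-middle (FIM): when the user prompt contains FIM_INDICATOR, generate()
# splits it around the marker and rebuilds it with StarCoder's FIM special tokens, e.g.
#   "def add(a, b):<FILL_HERE>return c"
# is sent to the worker as
#   "<fim_prefix>def add(a, b):<fim_suffix>return c<fim_middle>"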
notice_and_use_md = """
### Notice
- All the models in this demo run on 4th Generation Intel® Xeon® (Sapphire Rapids) processors, utilizing Intel® AMX operations and mixed-precision inference
- This demo is based on [Intel® Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch), [FastChat](https://github.com/lm-sys/FastChat/tree/main) and [BigCode - Playground](https://huggingface.co/spaces/bigcode/bigcode-playground)
### Terms of use
- By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It can produce factually incorrect output, and should not be relied on to produce factually accurate information. The service only provides limited safety measures and may generate lewd, biased or otherwise offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
- This app and its [supporting model](https://huggingface.co/bigcode) are provided for demonstration purposes and are not intended to serve as a replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card](https://hf.co/bigcode).
### License
- Any use or sharing of this demo constitutes your acceptance of the BigCode [OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) License Agreement and the use restrictions included within. Please contact us if you find any potential violation.
"""
theme = gr.themes.Soft(
font=[
gr.themes.GoogleFont("Open Sans"),
"ui-sans-serif",
"system-ui",
"sans-serif",
],
)
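# Ask the controller to refresh its registered workers, then fetch the list of model
# names currently being served. The model dropdown below is populated from this list.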
def get_model_list():
ret = requests.post(CONTROLLER_URL + "/refresh_all_workers")
assert ret.status_code == 200
ret = requests.post(CONTROLLER_URL + "/list_models")
return ret.json()["models"]
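# Stream tokens from a model worker via the controller's /worker_generate_stream
# endpoint. The worker replies with null-byte-delimited JSON chunks; each decoded
# chunk is expected to carry at least "text" and "error_code", plus "finish_reason"
# and "generation_time" on the final chunk (as consumed by generate() below).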
def model_worker_stream_iter(
model_name,
prompt,
*,
do_sample,
temperature,
repetition_penalty,
top_p,
max_new_tokens,
assist,
):
# Make requests
gen_params = {
"model": model_name,
"prompt": prompt,
"do_sample": do_sample,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
"assist": assist,
"stop_token_ids": [0], # FIXME should be decided by the model worker
}
# logger.info(f"==== request ====\n{gen_params}")
# Stream output
response = requests.post(
CONTROLLER_URL + "/worker_generate_stream",
# headers=headers,
json=gen_params,
stream=True,
timeout=WORKER_API_TIMEOUT,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
yield data
def check_assist(version):
return version is not None and "+AG" in version
def checkbox_update(version):
assist = check_assist(version)
return gr.update(value=assist, interactive=assist)
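# Main generation callback wired to the "Generate" button: builds the (optionally FIM)
# prompt, resolves the worker address for the selected model version, streams the
# completion into the code box, and updates the box label with the current status
# ("Waiting for server...", "Generating...", the generation time, or an error).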
def generate(
prompt, temperature=1.0, max_new_tokens=256, top_p=1.0, repetition_penalty=1.0, assistant=False, version="StarCoder",
):
if len(prompt) == 0:
idle_label = gr.update(label="Idle")
yield None, idle_label
return None, idle_label
do_sample = True
temperature = float(temperature)
if temperature < 1e-2:
do_sample = False
top_p = float(top_p)
fim_mode = False
if FIM_INDICATOR in prompt:
fim_mode = True
try:
prefix, suffix = prompt.split(FIM_INDICATOR)
except ValueError:
raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
error_label = gr.update(label="Error!")
# Query worker address
ret = requests.post(
CONTROLLER_URL + "/get_worker_address", json={"model": version}
)
if ret.json()["address"] == "":
output = WORKER_ERROR_MSG
yield output, error_label
return output, error_label
else:
yield None, gr.update(label="Waiting for server...")
stream = model_worker_stream_iter(
version,
prompt,
do_sample=do_sample,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_p=top_p,
max_new_tokens=max_new_tokens,
assist=assistant,
)
if fim_mode:
output = prefix
else:
output = prompt
previous_token = ""
generating_label = gr.update(label="Generating...")
for response in stream:
if response["finish_reason"] is not None:
done_label = gr.update(label=f'Generation time: {float(response["generation_time"]):.2f} sec')
yield output, done_label
return output, done_label
if response["error_code"] == 0:
text = response["text"]
if text == "<|endoftext|>":
if fim_mode:
output += suffix
else:
return output, error_label
else:
output += text
yield output, generating_label
else:
output = response["text"] + f"\n\n(error_code:{response['error_code']})"
yield output, error_label
return output, error_label
return output, error_label
examples = [
"""from typing import List
def has_close_elements(numbers: List[float], threshold: float) -> bool:
\"\"\" Check if in given list of numbers, are any two numbers closer to each other than
given threshold.
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
False
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
\"\"\"""",
"X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
# "def alternating(list1, list2):\n results = []\n for i in range(min(len(list1), len(list2))):\n results.append(list1[i])\n results.append(list2[i])\n if len(list1) > len(list2):\n <FILL_HERE>\n else:\n results.extend(list2[i+1:])\n return results",
]
css = ".generating {visibility: hidden}"
monospace_css = """
#q-input textarea {
font-family: 'Consolas', Courier, monospace;
}
"""
h1_css = """
h1 {
text-align: center;
display: block;
}
"""
css += monospace_css + h1_css
description = """
<h1>Intel Lab's ⭐StarCoder Demo💫</h1>
This demo performs inference on 4th Generation Intel® Xeon® (Sapphire Rapids) processors and includes the following models. For additional information, visit the <a href="https://huggingface.co/blog/intel-starcoder-quantization">StarCoder blog</a>.
1. Q8-StarCoder-15B+AG - StarCoder-15B quantized to 8-bit with Assisted Generation
2. Q4-StarCoder-15B - StarCoder-15B quantized to 4-bit
<b>Please note:</b> These models are not instruction-tuned and are not designed to follow natural-language instructions.
"""
models = get_model_list()
model = None
if len(models) > 0:
if 'Q8-StarCoder-15B+AG' in models:
model = 'Q8-StarCoder-15B+AG'
else:
model = models[0]
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
with gr.Column():
gr.Markdown(description)
with gr.Row():
version = gr.Dropdown(
models,
value=model,
label="Model",
show_label=False,
info="Choose a model from the list",
)
with gr.Column():
instruction = gr.Textbox(
placeholder="Enter your code here",
lines=5,
label="Input",
show_label=True,
elem_id="q-input",
)
with gr.Row():
submit = gr.Button("Generate", variant="primary", scale=85)
assistant = gr.Checkbox(
value=check_assist(model),
label="Assisted Generation",
interactive=check_assist(model),
scale=15,
)
output = gr.Code(elem_id="q-output", lines=30, show_label=True, label="")  # label text is updated at runtime with the generation status
with gr.Accordion("Advanced settings", open=False):
with gr.Row():
column_1, column_2 = gr.Column(), gr.Column()
with column_1:
temperature = gr.Slider(
label="Temperature",
value=DEFAULT_TEMPERATURE,
minimum=0.0,
maximum=1.0,
step=0.05,
interactive=True,
info="Higher values produce more diverse outputs",
)
max_new_tokens = gr.Slider(
label="Max new tokens",
value=DEFAULT_MAX_NEW,
minimum=0,
maximum=384,
step=32,
interactive=True,
info="The maximum numbers of new tokens",
)
with column_2:
top_p = gr.Slider(
label="Top-p (nucleus sampling)",
value=DEFAULT_TOPP,
minimum=0.0,
maximum=1.0,
step=0.05,
interactive=True,
info="Higher values sample more low-probability tokens",
)
repetition_penalty = gr.Slider(
label="Repetition penalty",
value=DEFAULT_REP_PENALTY,
minimum=1.0,
maximum=2.0,
step=0.05,
interactive=True,
info="Penalize repeated tokens",
)
gr.Examples(
examples=examples,
inputs=[instruction],
cache_examples=False,
outputs=[output],
)
gr.Markdown(notice_and_use_md)
gr.on(
triggers=[submit.click, instruction.submit],
fn=generate,
inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, assistant, version],
outputs=[output, output],  # first yielded value fills the code box, second updates its label (status text)
concurrency_limit=10,
)
version.change(checkbox_update, [version], [assistant])
demo.queue().launch()