|
import json |
|
import os |
|
import requests |
|
|
|
import gradio as gr |
|
|
|
|
|
# Base URL of the model controller service; every endpoint used in this file is
# built as CONTROLLER_URL + "/<endpoint>". Required: importing this module
# raises KeyError when CTRL_URL is unset. A trailing slash is stripped so that
# "http://host:port/" does not turn into "http://host:port//endpoint".
CONTROLLER_URL = os.environ["CTRL_URL"].rstrip("/")
|
# Initial values for the generation-settings sliders, overridable through the
# environment. Environment variables are always strings, so each value is
# converted explicitly; without this, DEFAULT_MAX_NEW=64 would reach the
# slider as the string "64".
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", 1.0))

DEFAULT_TOPP = float(os.environ.get("DEFAULT_TOPP", 1.0))

# NOTE: "PENELTY" is misspelled, but the name (and its env var) is kept as-is
# because other code refers to it.
DEFAULT_REP_PENELTY = float(os.environ.get("DEFAULT_REP_PENELTY", 1.0))

DEFAULT_MAX_NEW = int(os.environ.get("DEFAULT_MAX_NEW", 128))

# Shown to the user when the controller reports no worker for the chosen model.
WORKER_ERROR_MSG = "MODEL WORKER NOT FOUND. PLEASE REFRESH THIS PAGE."

# Request timeout in seconds. It must be numeric: `requests` rejects a string
# timeout, which is what an unconverted WORKER_TIMEOUT env var would be.
WORKER_API_TIMEOUT = float(os.environ.get('WORKER_TIMEOUT', 100))
|
|
|
|
|
# Fill-in-the-middle (FIM) special tokens used to rearrange a prompt into
# prefix / suffix / middle order, plus the marker a user types in the input to
# say "generate the code that goes here".
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX = "<fim_prefix>", "<fim_middle>", "<fim_suffix>"

FIM_INDICATOR = "<FILL_HERE>"
|
|
|
|
|
# Markdown shown at the bottom of the page: technical notice, terms of use and
# license. Fixed: the model-card link lacked a scheme (so it resolved as a
# relative path), "constitues" typo, and a missing article.
notice_and_use_md = """
### Notice
- All the models in this demo run on 4th Generation Intel® Xeon® (Sapphire Rapids) utilizing AMX operations and mixed precision inference
- This demo is based on [Intel® Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch), [FastChat](https://github.com/lm-sys/FastChat/tree/main) and [BigCode - Playground](https://huggingface.co/spaces/bigcode/bigcode-playground)
### Terms of use
- By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It can produce factually incorrect output, and should not be relied on to produce factually accurate information. The service only provides limited safety measures and may generate lewd, biased or otherwise offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
- This app and its [supporting model](https://huggingface.co/bigcode) are provided for demonstration purposes; not to serve as a replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card](https://huggingface.co/bigcode).
### License
- Any use or sharing of this demo constitutes your acceptance of the BigCode [OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) License Agreement and the use restrictions included within. Please contact us if you find any potential violation.
"""
|
|
|
|
|
# Gradio "Soft" theme: Open Sans (a Google font) first, then generic
# sans-serif font stacks as fallbacks.
theme = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"]
)
|
|
|
|
|
def get_model_list():
    """Refresh the controller's workers and return the names of the served models.

    Returns:
        The ``"models"`` list from the controller's ``/list_models`` response.

    Raises:
        requests.HTTPError: If either controller call returns an error status.
        requests.Timeout: If the controller does not answer within
            ``WORKER_API_TIMEOUT`` seconds.
    """
    # float() keeps this working even if the timeout constant came straight
    # from an environment variable as a string.
    timeout = float(WORKER_API_TIMEOUT)

    refresh = requests.post(CONTROLLER_URL + "/refresh_all_workers", timeout=timeout)
    # raise_for_status() instead of `assert`: assertions vanish under `python -O`.
    refresh.raise_for_status()

    listing = requests.post(CONTROLLER_URL + "/list_models", timeout=timeout)
    # Previously unchecked: an error page would surface as a confusing JSON/KeyError.
    listing.raise_for_status()
    return listing.json()["models"]
|
|
|
|
|
def model_worker_stream_iter(
    model_name,
    prompt,
    *,
    do_sample,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
    assist,
):
    """Stream generation messages for ``prompt`` from the controller.

    Posts the generation parameters to the controller's
    ``/worker_generate_stream`` endpoint and yields each message as soon as it
    arrives. Messages in the response body are NUL-byte (``b"\\0"``) delimited
    JSON documents.

    Args:
        model_name: Name of the model to run.
        prompt: Fully prepared prompt text (FIM tokens already applied, if any).
        do_sample: Whether to sample (False requests greedy decoding).
        temperature: Sampling temperature.
        repetition_penalty: Penalty applied to repeated tokens.
        top_p: Nucleus-sampling probability mass.
        max_new_tokens: Maximum number of tokens to generate.
        assist: Whether to request assisted generation.

    Yields:
        dict: One decoded JSON message per non-empty chunk.
    """
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "do_sample": do_sample,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "assist": assist,
        # Stop on token id 0 — presumably the model's end-of-text token; confirm
        # against the tokenizer if the model list changes.
        "stop_token_ids": [0],
    }

    # Using the response as a context manager guarantees the streaming
    # connection is released even when the consumer stops iterating early:
    # closing (or garbage-collecting) this generator raises GeneratorExit at
    # the `yield`, which exits the `with` block and closes the response.
    with requests.post(
        CONTROLLER_URL + "/worker_generate_stream",
        json=gen_params,
        stream=True,
        timeout=float(WORKER_API_TIMEOUT),
    ) as response:
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode())
|
|
|
|
|
def check_assist(version):
    """Tell whether the model name ``version`` denotes an assisted-generation model.

    Such models carry a ``"+AG"`` tag in their name (e.g. ``"Q8-StarCoder-15B+AG"``).
    ``None`` — no model selected — is never an assisted model.
    """
    if version is None:
        return False
    return "+AG" in version
|
|
|
|
|
def checkbox_update(version):
    """Return a Gradio update for the "Assisted Generation" checkbox.

    For assisted-generation models the box is ticked and editable; for any
    other model (or none) it is unticked and locked.
    """
    enabled = check_assist(version)
    return gr.update(interactive=enabled, value=enabled)
|
|
|
|
|
def generate(
    prompt,
    temperature=1.0,
    max_new_tokens=256,
    top_p=1.0,
    repetition_penalty=1.0,
    assistant=False,
    version="StarCoder",
):
    """Stream a code completion for ``prompt`` into the output widget.

    This is a generator used as a Gradio event handler: every ``yield`` is a
    pair ``(text, label_update)`` where ``text`` is the content to display
    (or None) and ``label_update`` is a ``gr.update(label=...)`` status.

    If ``prompt`` contains ``FIM_INDICATOR``, it is a fill-in-the-middle
    request: text before the marker is the prefix, text after it the suffix,
    and the model generates what goes in between.

    Args:
        prompt: Code typed by the user.
        temperature: Sampling temperature; below 0.01 greedy decoding is used.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty for repeated tokens.
        assistant: Whether to request assisted generation.
        version: Name of the model to query.

    Yields:
        tuple: ``(text, gr.update(label=...))``.

    Raises:
        ValueError: If ``prompt`` contains more than one ``FIM_INDICATOR``.
        requests.RequestException: On controller/worker connection failures.
    """
    # Nothing to do: reset the status label. A generator's return value is
    # discarded, so a bare `return` is used to stop.
    if not prompt:
        yield None, gr.update(label="Idle")
        return

    temperature = float(temperature)
    # Temperatures below 0.01 are treated as a request for greedy decoding.
    do_sample = temperature >= 1e-2
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    # Slider values may arrive as floats; a token budget must be an integer.
    max_new_tokens = int(max_new_tokens)

    fim_mode = FIM_INDICATOR in prompt
    if fim_mode:
        try:
            prefix, suffix = prompt.split(FIM_INDICATOR)
        except ValueError:
            # Two or more markers split into 3+ parts and the unpacking fails.
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!") from None
        # Rearrange into prefix/suffix/middle token order so the model
        # generates the code that belongs between prefix and suffix.
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"

    error_label = gr.update(label="Error!")

    # An empty address means no worker currently serves the requested model.
    worker = requests.post(
        CONTROLLER_URL + "/get_worker_address",
        json={"model": version},
        timeout=float(WORKER_API_TIMEOUT),
    )
    if worker.json()["address"] == "":
        yield WORKER_ERROR_MSG, error_label
        return
    yield None, gr.update(label="Waiting for server...")

    stream = model_worker_stream_iter(
        version,
        prompt,
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        assist=assistant,
    )

    # Generated text is appended to what the user typed; in FIM mode only the
    # prefix is echoed (the suffix is re-attached once the middle is complete).
    output = prefix if fim_mode else prompt
    generating_label = gr.update(label="Generating...")
    for message in stream:
        if message["finish_reason"] is not None:
            elapsed = float(message["generation_time"])
            yield output, gr.update(label=f"Generation time: {elapsed:.2f} sec")
            return
        if message["error_code"] != 0:
            # Replace the output with the worker's error text and code.
            yield message["text"] + f"\n\n(error_code:{message['error_code']})", error_label
            return
        text = message["text"]
        if text == "<|endoftext|>":
            # The end-of-text marker is never displayed. In FIM mode it means
            # the middle is done, so re-attach the suffix. In both modes keep
            # reading: a finish message (with timing) presumably follows, and
            # returning here would leave the label stuck on "Generating...".
            if fim_mode:
                output += suffix
            continue
        output += text
        yield output, generating_label

    # The stream ended without a finish message: publish the final text with
    # an error status rather than leaving the label on "Generating...".
    yield output, error_label
|
|
|
|
|
|
|
# Example prompts offered by the gr.Examples widget, which fills them into the
# input textbox. The first is a function stub with a doctest-style docstring
# (it looks like a HumanEval task); the second asks for a completion driven by
# a natural-language comment.
examples = [
    """from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    \"\"\"""",
    "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
]
|
|
|
# Custom CSS passed to gr.Blocks. Elements with class `.generating` are hidden —
# presumably to suppress Gradio's in-progress styling on the streaming output
# (NOTE(review): confirm against the Gradio version in use).
css = ".generating {visibility: hidden}"

# Monospace font for the code input box (the Textbox with elem_id="q-input").
monospace_css = """
#q-input textarea {
    font-family: monospace, 'Consolas', Courier, monospace;
}
"""

# Center the page title.
h1_css = """
h1 {
    text-align: center;
    display:block;
}
"""

# Final stylesheet = base rule + input font + title centering.
css += monospace_css + h1_css
|
|
|
# Header rendered with gr.Markdown at the top of the page (Markdown with inline
# HTML): title, hardware note, list of served models, and a usage caveat.
description = """
<h1>Intel Lab's ⭐StarCoder Demo💫</h1>

This demo performs inference using the 4th Generation Intel® Xeon® (Sapphire Rapids) and includes the following models. For additional information, visit the <a href="https://huggingface.co/blog/intel-starcoder-quantization">StarCoder blog</a>.
1. Q8-StarCoder-15B+AG - StarCoder-15B quantized to 8bit with Assisted Generation
2. Q4-StarCoder-15B - StarCoder-15B quantized to 4bit

<b>Please note:</b> These models are not designed for instruction purposes.
"""
|
|
|
# Ask the controller for the served models once, at start-up, and choose the
# initially selected one: the 8-bit assisted-generation model when available,
# otherwise the first model listed, or None when nothing is served.
models = get_model_list()
if 'Q8-StarCoder-15B+AG' in models:
    model = 'Q8-StarCoder-15B+AG'
elif models:
    model = models[0]
else:
    model = None
|
# Build the page layout: header, model picker, code input, generate button and
# assisted-generation toggle, output code box, advanced sampling settings,
# examples and the notice/terms footer.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            # Initial selection `model` is chosen from the controller's list above.
            version = gr.Dropdown(
                models,
                value=model,
                label="Model",
                show_label=False,
                info="Choose a model from the list",
            )
        with gr.Column():
            instruction = gr.Textbox(
                placeholder="Enter your code here",
                lines=5,
                label="Input",
                show_label=True,
                elem_id="q-input",
            )
            with gr.Row():
                submit = gr.Button("Generate", variant="primary", scale=85)
                # Ticked and editable only when the selected model supports
                # assisted generation; kept in sync by version.change below.
                assistant = gr.Checkbox(
                    value=check_assist(model),
                    label="Assisted Generation",
                    interactive=check_assist(model),
                    scale=15,
                )
            output = gr.Code(elem_id="q-output", lines=30, show_label=True, label="")
            with gr.Accordion("Advanced settings", open=False):
                with gr.Row():
                    column_1, column_2 = gr.Column(), gr.Column()
                    with column_1:
                        temperature = gr.Slider(
                            label="Temperature",
                            value=DEFAULT_TEMPERATURE,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                        max_new_tokens = gr.Slider(
                            label="Max new tokens",
                            value=DEFAULT_MAX_NEW,
                            minimum=0,
                            maximum=384,
                            step=32,
                            interactive=True,
                            info="The maximum numbers of new tokens",
                        )
                    with column_2:
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=DEFAULT_TOPP,
                            minimum=0.0,
                            maximum=1,
                            step=0.05,
                            interactive=True,
                            info="Higher values sample more low-probability tokens",
                        )
                        repetition_penalty = gr.Slider(
                            label="Repetition penalty",
                            value=DEFAULT_REP_PENELTY,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.05,
                            interactive=True,
                            info="Penalize repeated tokens",
                        )
            # Clicking an example fills the input box; nothing is precomputed.
            gr.Examples(
                examples=examples,
                inputs=[instruction],
                cache_examples=False,
                outputs=[output],
            )
            gr.Markdown(notice_and_use_md)

    # Both the Generate button and submitting the textbox run generate().
    # generate() yields (text, label update) pairs, so `output` is listed twice:
    # once for its value and once for its label. At most 10 runs of this event
    # execute concurrently.
    gr.on(
        triggers=[submit.click, instruction.submit],
        fn=generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, assistant, version],
        outputs=[output, output],
        concurrency_limit=10,
    )
    # Re-evaluate the assisted-generation checkbox whenever the model changes.
    version.change(checkbox_update, [version], [assistant])

# The queue is required for streaming (generator) event handlers.
demo.queue().launch()