Spaces:
Runtime error
| import sys | |
| import os | |
| import logging as log | |
| from typing import Generator | |
| import gradio as gr | |
| from gradio.themes.utils import sizes | |
| from text_generation import Client | |
| from src.request import StarCoderRequest, StarCoderRequestConfig | |
| from src.utils import ( | |
| get_file_as_string, | |
| get_sections, | |
| get_url_from_env_or_default_path, | |
| preview | |
| ) | |
| from constants import ( | |
| FIM_MIDDLE, | |
| FIM_PREFIX, | |
| FIM_SUFFIX, | |
| END_OF_TEXT, | |
| MIN_TEMPERATURE, | |
| ) | |
| from settings import ( | |
| DEFAULT_PORT, | |
| DEFAULT_STARCODER_API_PATH, | |
| DEFAULT_STARCODER_BASE_API_PATH, | |
| ) | |
HF_TOKEN = os.getenv("HF_TOKEN")

# The inference clients below need a Hugging Face token. When it is missing,
# explain how to obtain one on stderr and stop with exit status 1 rather than
# letting a traceback surface later.
if not HF_TOKEN:
    ERR_MSG = """
Please set the HF_TOKEN environment variable with your Hugging Face API token.
You can get one by signing up at https://huggingface.co/join and then visiting
https://huggingface.co/settings/tokens."""
    sys.stderr.write(ERR_MSG + "\n")
    raise SystemExit(1)
# Endpoint for each model version. get_url_from_env_or_default_path presumably
# prefers the named environment variable and falls back to the settings default.
API_URL_STAR = get_url_from_env_or_default_path(
    "STARCODER_API", DEFAULT_STARCODER_API_PATH
)
API_URL_BASE = get_url_from_env_or_default_path(
    "STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH
)

# Echo the resolved configuration; the token is passed with the ofuscate flag.
preview("StarCoder Model URL", API_URL_STAR)
preview("StarCoderBase Model URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)

# Static front-end assets, loaded in this order: stylesheet, share-button
# script, share icon, loading icon.
_styles, _script, _sharing_icon_svg, _loading_icon_svg = map(
    get_file_as_string,
    ("styles.css", "community-btn.js", "community-icon.svg", "loading-icon.svg"),
)
# Read ./README.md and split it on its "---" separators into four parts:
# `manifest` (not used in this file), `description` (page header),
# `disclaimer` (shown under the settings) and `formats` (shown at the bottom).
readme_file_content = get_file_as_string("README.md", path="./")
manifest, description, disclaimer, formats = get_sections(
    readme_file_content, "---", up_to=4
)
# Monochrome base theme with indigo/blue/slate hues, small corner radii,
# large text, and IBM Plex Sans (weights 400 and 600) falling back to the
# system sans-serif fonts.
theme = gr.themes.Monochrome(
    font=[gr.themes.GoogleFont("IBM Plex Sans", [400, 600]), "ui-sans-serif", "system-ui", "sans-serif"],
    text_size=sizes.text_lg,
    radius_size=sizes.radius_sm,
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)
# Every request to the inference endpoints authenticates with the HF token.
HEADERS = dict(Authorization=f"Bearer {HF_TOKEN}")

# One text-generation client per model version, both sharing HEADERS.
client_star, client_base = (
    Client(endpoint, headers=HEADERS) for endpoint in (API_URL_STAR, API_URL_BASE)
)
def get_tokens_collector(request: StarCoderRequest) -> Generator[str, None, None]:
    """Stream the text of each generated token for *request*.

    The client is chosen from ``request.settings.version``: ``client_star``
    for "StarCoder", ``client_base`` for anything else. The prompt is sent
    with the keyword arguments returned by ``request.settings.kwargs()``.
    Tokens whose id is 0 are skipped.
    """
    if request.settings.version == "StarCoder":
        model_client = client_star
    else:
        model_client = client_base
    stream = model_client.generate_stream(request.prompt, **request.settings.kwargs())
    # NOTE(review): id 0 is presumably the end-of-text token (END_OF_TEXT is
    # imported but not used for this check) — confirm against the tokenizer.
    yield from (chunk.token.text for chunk in stream if chunk.token.id != 0)
def get_tokens_accumulator(request: StarCoderRequest) -> Generator[str, None, None]:
    """Yield the full text produced so far, growing one token at a time.

    In FIM mode the text starts from ``request.prefix`` and, once the stream
    ends, ``request.suffix`` is appended and yielded. Otherwise the text starts
    from ``request.prompt``. The final value is always yielded once more with
    a trailing newline.
    """
    text = request.prefix if request.fim_mode else request.prompt
    for piece in get_tokens_collector(request=request):
        text += piece
        yield text

    if request.fim_mode:
        text = text + request.suffix
        yield text

    # Closing newline after the completed text.
    yield text + "\n"
def get_tokens_linker(request: StarCoderRequest) -> str:
    """Return every generated token text for *request* joined into one string.

    Only the generated tokens are included: neither the prompt nor any FIM
    prefix/suffix is added (contrast with ``get_tokens_accumulator``).
    """
    # str.join consumes the generator directly; wrapping it in list() only
    # built a throwaway intermediate list.
    return "".join(get_tokens_collector(request))
def generate(
    prompt: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
) -> Generator[str, None, None]:
    """Stream the completion of *prompt* for the "Generate" button.

    The sampling parameters and model *version* are packed into a
    ``StarCoderRequestConfig``. Each yielded value is the whole text so far,
    as produced by ``get_tokens_accumulator``.
    """
    config = StarCoderRequestConfig(
        version=version,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
    )
    request = StarCoderRequest(prompt=prompt, settings=config)
    yield from get_tokens_accumulator(request)
def process_example(
    prompt: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
) -> Generator[str, None, None]:
    """Generate the completion for an example prompt in a single shot.

    Builds the same request as ``generate`` but collects the whole completion
    first, then yields it exactly once. The yielded text is the generated
    tokens only (no prompt), as returned by ``get_tokens_linker``.
    """
    request = StarCoderRequest(
        prompt=prompt,
        settings=StarCoderRequestConfig(
            version=version,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ),
    )
    # get_tokens_linker returns a str. `yield from` would iterate it and emit
    # one character per update, so the output would end up showing only the
    # last character. Yield the complete string once instead.
    yield get_tokens_linker(request)
# TODO: move these examples into the README as well.
examples = [
    # Python: continue a scikit-learn style training snippet.
    "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n"
    "# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
    # JavaScript: complete a function from its doc comment and signature.
    "// Returns every other value in the array as a new array.\n"
    "function everyOther(arr) {",
    # Contains the <FILL_HERE> marker, presumably turned into a
    # fill-in-the-middle request by StarCoderRequest.
    "def alternating(list1, list2):\n"
    " results = []\n"
    " for i in range(min(len(list1), len(list2))):\n"
    " results.append(list1[i])\n"
    " results.append(list2[i])\n"
    " if len(list1) > len(list2):\n"
    " <FILL_HERE>\n"
    " else:\n"
    " results.extend(list2[i+1:])\n"
    " return results",
]
# Page layout, top to bottom: the README description; a code prompt with a
# Generate button and an output code panel; collapsible sampling settings
# next to a model-version dropdown; the README disclaimer; a "Share to
# community" button; clickable example prompts; and the README format notes.
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        # Prompt editor, Generate button and result panel.
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30)
        # Sampling parameters (accordion starts closed) and the model selector.
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Advanced settings", open=False):
                    with gr.Row():
                        # Two side-by-side columns holding two sliders each.
                        column_1, column_2 = gr.Column(), gr.Column()
                        with column_1:
                            # NOTE(review): MIN_TEMPERATURE is imported in this
                            # file but unused, and this slider goes down to 0.0.
                            # Confirm the backend accepts a temperature of 0.
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                            max_new_tokens = gr.Slider(
                                label="Max new tokens",
                                value=256,
                                minimum=0,
                                maximum=8192,
                                step=64,
                                interactive=True,
                                info="The maximum numbers of new tokens",
                            )
                        with column_2:
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.90,
                                minimum=0.0,
                                maximum=1,
                                step=0.05,
                                interactive=True,
                                info="Higher values sample more low-probability tokens",
                            )
                            repetition_penalty = gr.Slider(
                                label="Repetition penalty",
                                value=1.2,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.05,
                                interactive=True,
                                info="Penalize repeated tokens",
                            )
            with gr.Column():
                # Passed to generate() as `version`. get_tokens_collector uses
                # client_star for "StarCoder" and client_base otherwise.
                version = gr.Dropdown(
                    ["StarCoderBase", "StarCoder"],
                    value="StarCoder",
                    label="Version",
                    info="",
                )
        gr.Markdown(disclaimer)
        # Share button and its two icons. The click behavior comes from
        # community-btn.js (wired up below).
        with gr.Group(elem_id="share-btn-container"):
            community_icon = gr.HTML(_sharing_icon_svg, visible=True)
            loading_icon = gr.HTML(_loading_icon_svg, visible=True)
            share_button = gr.Button(
                "Share to community", elem_id="share-btn", visible=True
            )
        # NOTE(review): with cache_examples=False and run_on_click not set,
        # clicking an example presumably only fills the Code box and does not
        # call process_example. Confirm against the Gradio version in use.
        gr.Examples(
            examples=examples,
            inputs=[instruction],
            cache_examples=False,
            fn=process_example,
            outputs=[output],
        )
        gr.Markdown(formats)
    # Stream each value yielded by generate() into the output panel. The
    # inputs map positionally onto generate(prompt, temperature,
    # max_new_tokens, top_p, repetition_penalty, version).
    # NOTE(review): per the Gradio docs, max_batch_size only applies when
    # batch=True, which is not set here. Confirm whether batching was intended.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
        # preprocess=False,
        max_batch_size=8,
        show_progress=True
    )
    # No Python callback and no inputs/outputs: only runs the
    # community-btn.js script in the browser.
    share_button.click(None, [], [], _js=_script)
# Turn on the request queue (concurrency_count=16), then start the server in
# debug mode on the configured port. Blocks.queue returns the Blocks itself,
# so these two calls are equivalent to chaining them.
demo.queue(concurrency_count=16)
demo.launch(debug=True, server_port=DEFAULT_PORT)