# Source captured from a Hugging Face Spaces page (Space status at capture time: "Runtime error").
import os | |
import gradio as gr | |
from text_generation import Client, errors | |
from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css | |
# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Inference API endpoint for the OctoCoder model (no surrounding whitespace:
# the value is used verbatim as the request URL).
API_URL = "https://api-inference.huggingface.co/models/bigcode/octocoder"
# Page theme: Gradio's Monochrome base recoloured with indigo / blue / slate
# hues, small corner radii, and a font stack that prefers "Open Sans" and
# falls back to the platform sans-serif fonts.
theme = gr.themes.Monochrome(
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Streaming text-generation client for API_URL. The Authorization header is
# only attached when a token is configured; otherwise no headers are sent
# (rather than the meaningless "Bearer None").
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)
def _build_prompt(query: str) -> str:
    """Wrap *query* in the Question/Answer template, ensuring it ends with a period."""
    if not query.endswith("."):
        query += "."
    return f"Question: {query}\n\nAnswer:"


def _normalize_top_p(top_p):
    """Map a UI top-p value onto what the text-generation client accepts.

    The client only accepts top_p strictly inside (0, 1). A value of 1.0 keeps
    the whole distribution, which is exactly what "no nucleus filtering" means,
    so values >= 1.0 become None. Values below 0.01 are raised to 0.01, mirroring
    the temperature floor in ``generate``.
    """
    top_p = float(top_p)
    if top_p >= 1.0:
        return None
    return max(top_p, 1e-2)


def _loading_message_stream():
    """Yield the "model is loading" notice one word at a time (typewriter effect).

    Each yielded value is the cumulative text. Every word but the last is
    followed by a space, and the final yield is the complete message.
    """
    # NOTE(review): "π" looks like a mis-encoded emoji (UTF-8 bytes rendered in
    # a Greek code page); confirm the intended character before changing it.
    message = "Please wait for a while, The OctoCoder model is currently loading... π"
    words = message.split(" ")
    output = ""
    for index, word in enumerate(words):
        is_last = index == len(words) - 1
        output += word if is_last else f"{word} "
        # Always yield, including for the last word. A generator's `return`
        # value is discarded by Gradio, so text that is only returned is never shown.
        yield output


def generate(
    query: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream an OctoCoder answer for *query*, yielding the text generated so far.

    The query is wrapped in a ``Question: ...\\n\\nAnswer:`` prompt and sent to
    the endpoint through the module-level ``client`` with sampling enabled and
    a fixed seed (42). Streaming stops at the ``<|endoftext|>`` token.

    Args:
        query: The user's instruction.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling mass; >= 1.0 disables nucleus filtering and
            values below 0.01 are raised to 0.01.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The cumulative generated text. If the endpoint raises
        ``errors.UnknownError``, the error is printed and a "model is loading"
        notice is streamed instead.
    """
    prompt = _build_prompt(query)
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=int(max_new_tokens),
        top_p=_normalize_top_p(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )
    try:
        stream = client.generate_stream(prompt, **generate_kwargs)
        output = ""
        for response in stream:
            if response.token.text == "<|endoftext|>":
                break
            output += response.token.text
            yield output
    except errors.UnknownError as e:
        print(f"Error: {e}")
        yield from _loading_message_stream()
def process_example(query, max_new_tokens=256, **kwargs):
    """Run ``generate`` to completion and return its final output.

    Intended as the ``fn`` of ``gr.Examples``, whose inputs are
    ``[instruction, max_new_tokens]``. Gradio calls event functions with
    positional arguments in ``inputs`` order, so those are the first two
    parameters here. Other ``generate`` options (``temperature``, ``top_p``,
    ``repetition_penalty``) may be passed by keyword. Calls that use only
    keywords, e.g. ``process_example(query=..., max_new_tokens=...)``, keep
    working.

    Returns:
        The last value yielded by ``generate``, or ``""`` if it yielded nothing.
    """
    result = ""  # guards against an empty stream leaving the name unbound
    for result in generate(query, max_new_tokens=max_new_tokens, **kwargs):
        pass
    return result
# Page CSS passed to gr.Blocks. `.generating` is hidden; presumably this
# suppresses Gradio's "generating" animation while output streams (confirm
# against the Gradio version in use).
css = ".generating {visibility: hidden}"
# Monospace font for the query textbox (elem_id="q-input"). The generic
# `monospace` family comes last. A generic family listed first always matches,
# which would make the named fonts unreachable.
monospace_css = """
#q-input textarea {
    font-family: 'Consolas', Courier, monospace;
}
"""
css += share_btn_css + monospace_css
# HTML header rendered at the top of the page via gr.Markdown: banner image,
# page title, and a short description of the OctoCoder model with a link to
# its model page.
description = """
<div style="text-align: center;">
<center><img src='https://raw.githubusercontent.com/bigcode-project/octopack/31f3320f098703c7910e43492c39366eeea68d83/banner.png' width='70%'/></center>
<br>
<h1><u> OctoCoder Demo </u></h1>
</div>
<br>
<div style="text-align: center;">
<p>This is a demo to demonstrate the capabilities of <a href="https://huggingface.co/bigcode/octocoder">OctoCoder</a> model by showing how it can be used to generate code by following the instructions provided in the input.</p>
<p><strong>OctoCoder</strong> is an instruction tuned model with 15.5B parameters created by finetuning StarCoder on CommitPackFT & OASST</p>
</div>
"""
# License / intended-use notice rendered below the output (Markdown + inline
# HTML). The backslash at the end of the first line joins both lines into a
# single line of text.
disclaimer = """⚠️<b>Any use or sharing of this demo constitutes your acceptance of the BigCode [OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) License Agreement and the use restrictions included within.</b>\
<br>**Intended Use**: this app and its [supporting model](https://huggingface.co/bigcode/octocoder) are provided for demonstration purposes; not to serve as replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card](https://huggingface.co/bigcode/octocoder)."""
# Example prompts for gr.Examples. Each row is [instruction, max_new_tokens],
# matching the example inputs [instruction, max_new_tokens]. Token budgets are
# multiples of 64 within the slider's 0..8192 range.
examples = [
    ["Please write a function in Python that performs bubble sort.", 256],
    [
        # The snippet must keep its indentation, or the model is asked to
        # explain syntactically invalid Python.
        '''Explain the following piece of code
def count_unique(s):
    s = s.lower()
    s_split = list(s)
    valid_chars = [char for char in s_split if char.isalpha() or char == " "]
    valid_sentence = "".join(valid_chars)
    uniques = set(valid_sentence.split(" "))
    return len(uniques)''',
        512,
    ],
    [
        "Write an efficient Python function that takes a given text and returns its Morse code equivalent without using any third party library",
        512,
    ],
    ["Write HTML and CSS code to render a clock", 8000],
]
# Assemble the page. Components created inside a `with` context are placed in
# that container, so the nesting below mirrors the visual layout.
# NOTE(review): the indentation of the captured source was lost; this nesting is
# reconstructed from the `with` statements — confirm against the deployed app.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Sampling controls; their values are passed to `generate` by
                # `submit.click` below.
                with gr.Accordion("Settings", open=True):
                    with gr.Row():
                        column_1, column_2 = gr.Column(), gr.Column()
                        with column_1:
                            # `generate` raises temperatures below 0.01 to 0.01,
                            # so the 0.0 minimum is safe.
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                            # NOTE(review): minimum=0 allows requesting zero new
                            # tokens, which the inference server may reject — confirm.
                            max_new_tokens = gr.Slider(
                                label="Max new tokens",
                                value=256,
                                minimum=0,
                                maximum=8192,
                                step=64,
                                interactive=True,
                                info="The maximum numbers of new tokens",
                            )
                        with column_2:
                            # NOTE(review): maximum=1 lets the user choose top_p=1.0;
                            # the text_generation client validates top_p to the open
                            # interval (0, 1) — confirm `generate` handles the endpoints.
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.90,
                                minimum=0.0,
                                maximum=1,
                                step=0.05,
                                interactive=True,
                                info="Higher values sample more low-probability tokens",
                            )
                            repetition_penalty = gr.Slider(
                                label="Repetition penalty",
                                value=1.2,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.05,
                                interactive=True,
                                info="Penalize repeated tokens",
                            )
        with gr.Row():
            with gr.Column():
                # elem_id="q-input" is targeted by the monospace CSS rule.
                instruction = gr.Textbox(
                    placeholder="Enter your query here",
                    lines=5,
                    label="Input",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30, label="Output")
        gr.Markdown(disclaimer)
        # Share-to-community widgets; their HTML/CSS/JS come from share_btn.
        with gr.Group(elem_id="share-btn-container"):
            community_icon = gr.HTML(community_icon_html, visible=True)
            loading_icon = gr.HTML(loading_icon_html, visible=True)
            share_button = gr.Button(
                "Share to community", elem_id="share-btn", visible=True
            )
        # Clicking an example fills `instruction` and `max_new_tokens`. With
        # cache_examples=False, `process_example` is not run ahead of time.
        gr.Examples(
            examples=examples,
            inputs=[instruction, max_new_tokens],
            cache_examples=False,
            fn=process_example,
            outputs=[output],
        )
    # Stream `generate` output into the code view; the input order must match
    # generate(query, temperature, max_new_tokens, top_p, repetition_penalty).
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=[output],
    )
    # No Python handler (fn is None): only the `share_js` snippet runs, in the browser.
    share_button.click(None, [], [], _js=share_js)
# NOTE(review): `queue(concurrency_count=...)` and the `_js=` keyword above are
# Gradio 3.x APIs that Gradio 4 removed/renamed (`default_concurrency_limit`,
# `js=`); pin gradio<4 or migrate — a version mismatch fails at startup.
demo.queue(concurrency_count=16).launch(debug=True)