# luna-playground / app.py — Hugging Face Space by lvwerra (HF staff).
# Last commit 549ab4b: "Update app.py" (6.37 kB).
import json
import os
import shutil
import gradio as gr
from huggingface_hub import Repository
from text_generation import Client
from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
# Deployment configuration read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
API_URL = os.environ.get("API_URL")

# Special tokens used to build fill-in-the-middle (FIM) prompts, plus the
# marker a user types into the prompt to say "fill here".
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX = "<fim_prefix>", "<fim_middle>", "<fim_suffix>"
FIM_INDICATOR = "<FILL_HERE>"
# Markdown help text shown in the UI via gr.Markdown(FORMATS). It lists the
# special-token layouts seen in training data. The last section describes the
# <FILL_HERE> marker (FIM_INDICATOR) that generate() rewrites into a FIM prompt.
# NOTE(review): this is a regular (non-raw) string, so the "\n" on the
# <reponame> line becomes an actual line break in the rendered code block,
# not a literal "\n". Confirm that this is the intended display.
FORMATS = """## Model formats
The model is pretrained on code and in addition to the pure code data it is formatted with special tokens. E.g. prefixes specifying the source of the file or special tokens separating code from a commit message. See below:
### Prefixes
Any combination of the three following prefixes can be found in pure code files:
```
<reponame>REPONAME<filename>FILENAME<gh_stars>STARS\ncode<|endoftext|>
```
STARS can be one of: 0, 1-10, 10-100, 100-1000, 1000+
### Commits
The commits data is formatted as follows:
```
<commit_before>code<commit_msg>text<commit_after>code<|endoftext|>
```
### Jupyter structure
Jupyter notebooks were both trained in form of Python scripts as well as the following structured format:
```
<start_jupyter><jupyter_text>text<jupyter_code>code<jupyter_output>output<jupyter_text>
```
### Issues
We also trained on GitHub issues using the following formatting:
```
<issue_start><issue_comment>text<issue_comment>...<issue_closed>
```
### Fill-in-the-middle
Fill in the middle requires rearranging the model inputs. The playground does this for you - all you need is to specify where to fill:
```
code before<FILL_HERE>code after
```
"""
# Visual theme for the Blocks app: Monochrome base, indigo/blue/slate hues,
# small corner radius, with Open Sans falling back to system sans-serif fonts.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Client for the inference endpoint at API_URL; generate() streams through it.
# NOTE: no Authorization header is sent. A Bearer-token header built from
# HF_TOKEN (headers={"Authorization": f"Bearer {HF_TOKEN}"}) was disabled here;
# pass it again if the endpoint starts requiring authentication.
client = Client(API_URL)
def generate(prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion for ``prompt`` from the text-generation endpoint.

    If ``prompt`` contains exactly one ``FIM_INDICATOR`` it is rewritten into a
    fill-in-the-middle prompt (``FIM_PREFIX``/``FIM_SUFFIX``/``FIM_MIDDLE``).
    The streamed text is then shown as ``prefix`` + generated middle. When the
    ``<|endoftext|>`` token arrives, the user's suffix is appended after a
    newline. If the stream ends without that token (for example because
    ``max_new_tokens`` was reached), the suffix is not appended.

    Args:
        prompt: User prompt, optionally containing one ``FIM_INDICATOR``.
        temperature: Sampling temperature; values below 1e-2 are raised to 1e-2.
        max_new_tokens: Maximum number of tokens to generate. It is coerced to
            ``int`` and raised to at least 1.
        top_p: Nucleus-sampling probability. Values >= 1 disable nucleus
            sampling (sent as ``None``). Values <= 0 are raised to 1e-2.
        repetition_penalty: Penalty for repeated tokens (coerced to ``float``).

    Yields:
        The accumulated output text after each streamed token.

    Returns:
        The final accumulated output, which is the generator's
        ``StopIteration.value``.

    Raises:
        ValueError: If ``prompt`` contains more than one ``FIM_INDICATOR``.
    """
    # UI sliders can deliver boundary values the text_generation client rejects
    # (it requires temperature > 0 and 0 < top_p < 1), so normalize them here.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    if top_p >= 1.0:
        # top_p == 1 means "keep the whole distribution", i.e. no nucleus filtering.
        top_p = None
    elif top_p <= 0.0:
        top_p = 1e-2
    # Requesting zero new tokens is meaningless; the slider allows 0.
    max_new_tokens = max(int(max_new_tokens), 1)
    repetition_penalty = float(repetition_penalty)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed so identical requests should sample identically.
        seed=42,
    )

    parts = prompt.split(FIM_INDICATOR)
    fim_mode = len(parts) > 1
    if fim_mode:
        if len(parts) != 2:
            # Previously this ValueError was constructed but never raised, so
            # the code fell through to an UnboundLocalError on prefix/suffix.
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
        prefix, suffix = parts
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
        # Show the user's code before the fill point, not the FIM-tokenized prompt.
        output = prefix
    else:
        output = prompt

    stream = client.generate_stream(prompt, **generate_kwargs)
    for response in stream:
        token_text = response.token.text
        if fim_mode and token_text == "<|endoftext|>":
            # The middle is finished: close it off with the user's suffix.
            output += suffix + "\n" + token_text
        else:
            output += token_text
        yield output
    return output
# Starter prompts listed in the gr.Examples panel: two function stubs
# followed by two class stubs.
examples = [
    "def print_hello_world():",
    "def fibonacci(n):",
    "class TransformerDecoder(nn.Module):",
    "class ComplexNumbers:",
]
def process_example(args):
    """Run ``generate`` on an example prompt to completion and return the final text.

    This is the ``fn`` given to ``gr.Examples``; ``args`` is the example prompt
    (the single ``instruction`` input).

    The result is read from the generator's return value
    (``StopIteration.value``). It equals the last yielded text, or the initial
    text when the stream yields nothing. The previous
    ``for x in ...: pass; return x`` raised UnboundLocalError in that
    empty-stream case.
    """
    stream = generate(args)
    while True:
        try:
            next(stream)
        except StopIteration as finished:
            return finished.value
# Page CSS: hide elements with the `.generating` class (presumably Gradio's
# in-progress overlay, so streamed text stays visible), then add the
# share-button styles.
css = "".join((".generating {visibility: hidden}", share_btn_css))
# Page layout and event wiring.
# NOTE(review): this copy of the file lost its indentation, so the nesting below
# is reconstructed. Event listeners are kept inside the gr.Blocks context.
# Confirm against the deployed Space.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        # Header. The string lines stay at column 0 so the string content is
        # exactly the original's; indenting them would add leading spaces.
        gr.Markdown(
            """\
# BigCode - Playground
_Note:_ this is an internal playground - please do not share. The deployment can also change and thus the space not work as we continue development.\
"""
        )
        with gr.Row():
            # Wider left column: prompt, output, share widget, examples, format help.
            with gr.Column(scale=3):
                instruction = gr.Textbox(placeholder="Enter your prompt here", label="Prompt", elem_id="q-input")
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output")
                # Share widget; the elem_ids are presumably targeted by share_btn_css / share_js.
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_html, visible=True)
                    loading_icon = gr.HTML(loading_icon_html, visible=True)
                    share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
                # NOTE(review): with cache_examples=False, Gradio presumably only fills
                # the prompt on click and does not call process_example. Confirm.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(FORMATS)
            # Narrower right column: sampling controls passed to generate().
            with gr.Column(scale=1):
                # generate() raises values below 1e-2 to 1e-2, so 0.0 here is tolerated.
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                # NOTE(review): minimum=0 lets the user request zero new tokens,
                # which the inference server may reject. Confirm.
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=256,
                    minimum=0,
                    maximum=8192,
                    step=64,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                # NOTE(review): the text_generation client appears to require
                # 0 < top_p < 1, so the endpoints 0.0 and 1 may be rejected. Confirm.
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.90,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    value=1.2,
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    interactive=True,
                    info="Penalize repeated tokens",
                )
    # Clicking Generate or submitting the prompt box both stream generate()
    # into the output code box, using the current slider values.
    submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty], outputs=[output])
    instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty], outputs=[output])
    # No Python callback, inputs or outputs: only the share_js snippet runs.
    share_button.click(None, [], [], _js=share_js)
# Enable request queuing (concurrency_count=16) and start the server with debug output.
(
    demo
    .queue(concurrency_count=16)
    .launch(debug=True)
)