File size: 3,901 Bytes
2fc3868
5e467f7
422c008
 
2fc3868
5e467f7
 
 
 
2fc3868
5e467f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
091a2b3
 
 
 
 
 
 
5e467f7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import os
from text_generation import Client


# Marker a user places in the prompt where the model should fill in code.
FIM_INDICATOR = "<FILL_HERE>"

# Special tokens used to assemble a fill-in-the-middle prompt.
FIM_PREFIX = "<fim_prefix>"
FIM_SUFFIX = "<fim_suffix>"
FIM_MIDDLE = "<fim_middle>"

# Inference endpoint URL and the optional access token read from the environment
# (None when the HF_TOKEN variable is not set).
API_URL_BASE = "https://api-inference.huggingface.co/models/bigcode/starcoderbase"
HF_TOKEN = os.getenv("HF_TOKEN")

# Gradio theme for the demo: Monochrome base with custom hues, small corner
# radius, and "Open Sans" first in the font list followed by generic fallbacks.
theme = gr.themes.Monochrome(
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# CSS rule that gives the textarea inside the element with id "q-input" a
# monospace font.
monospace_css = (
    "\n"
    "#q-input textarea {\n"
    "    font-family: monospace, 'Consolas', Courier, monospace;\n"
    "}\n"
)

# Complete stylesheet: the ".generating" rule, the monospace rule above, and the
# ".gradio-container" text colour rule, concatenated in that order.
css = "".join(
    [
        ".generating {visibility: hidden}",
        monospace_css,
        ".gradio-container {color: black}",
    ]
)

# HTML snippet shown at the top of the page.
description = (
    "\n"
    '<div style="text-align: center;">\n'
    "    <p>This is a demo to generate code with "
    '<a href="https://huggingface.co/bigcode/starcoder" style=\'color: #e6b800;\'>StarCoder</a></p>\n'
    "</div>\n"
)

# Sample prompts offered in the UI; each literal is split at its newlines so the
# prompt layout is readable in source.
examples = [
    # Python prompt: code line followed by a comment describing the task.
    (
        "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n"
        "\n"
        "# Train a logistic regression model, predict the labels on the test set and compute the accuracy score"
    ),
    # JavaScript prompt: comment plus an open function header.
    (
        "// Returns every other value in the array as a new array.\n"
        "function everyOther(arr) {"
    ),
    # Prompt containing the <FILL_HERE> marker.
    (
        "def alternating(list1, list2):\n"
        "   results = []\n"
        "   for i in range(min(len(list1), len(list2))):\n"
        "       results.append(list1[i])\n"
        "       results.append(list2[i])\n"
        "   if len(list1) > len(list2):\n"
        "       <FILL_HERE>\n"
        "   else:\n"
        "       results.extend(list2[i+1:])\n"
        "   return results"
    ),
]

# Streaming client for the endpoint at API_URL_BASE.
# The Authorization header is only sent when HF_TOKEN is set; formatting a
# missing token would otherwise produce the literal header value "Bearer None".
client_base = Client(
    API_URL_BASE,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)

def generate(
    prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, version="StarCoder",
):
    """Stream a completion for ``prompt``, yielding the text accumulated so far.

    If ``prompt`` contains ``FIM_INDICATOR`` it is split around the marker and
    re-assembled as a fill-in-the-middle prompt
    (``FIM_PREFIX`` + prefix + ``FIM_SUFFIX`` + suffix + ``FIM_MIDDLE``). The
    streamed tokens are then appended to the prefix, and the suffix is appended
    when the ``<|endoftext|>`` token arrives. Otherwise tokens are appended to
    the prompt itself and generation stops at ``<|endoftext|>``.

    Args:
        prompt: Text to complete, optionally containing one ``FIM_INDICATOR``.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Forwarded to the client unchanged.
        top_p: Forwarded to the client as a float.
        repetition_penalty: Forwarded to the client unchanged.
        version: Accepted for backward compatibility. This module only creates
            ``client_base``, so every request is sent through it regardless of
            the value.

    Yields:
        The output text after each streamed token.

    Raises:
        ValueError: If ``prompt`` contains more than one ``FIM_INDICATOR``.
    """
    # Clamp the temperature to a small positive floor.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    fim_mode = FIM_INDICATOR in prompt
    if fim_mode:
        parts = prompt.split(FIM_INDICATOR)
        # Exactly one marker yields exactly two parts; anything else is rejected.
        if len(parts) != 2:
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
        prefix, suffix = parts
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"

    # The "StarCoder" branch used to reference a `client` object that is never
    # defined in this module (NameError on the default path); `client_base` is
    # the only client available, so it serves every version.
    stream = client_base.generate_stream(prompt, **generate_kwargs)

    # In fill-in-the-middle mode the visible text starts from the prefix only;
    # the special FIM tokens in the rewritten prompt are never shown.
    output = prefix if fim_mode else prompt

    for response in stream:
        token_text = response.token.text
        if token_text == "<|endoftext|>":
            if not fim_mode:
                return output
            # Close the gap: the generated middle is followed by the original suffix.
            output += suffix
        else:
            output += token_text
        yield output
    return output
def process_example(args):
    """Run ``generate`` to completion on ``args`` and return the last text it yields.

    Args:
        args: The prompt passed to ``generate`` as its first argument.

    Returns:
        The final value yielded by ``generate``. If ``generate`` yields nothing,
        ``args`` is returned unchanged instead of failing on an unbound loop
        variable.
    """
    final_output = args
    for final_output in generate(args):
        pass
    return final_output

# Page layout: description on top, then the code input, the Generate button and
# the output panel, followed by the clickable example prompts.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30)
        gr.Examples(
            examples=examples,
            inputs=[instruction],
            cache_examples=False,
            fn=process_example,
            outputs=[output],
        )

    # Connect the button to the generator so clicking "Generate" sends the
    # textbox content to `generate` and shows what it yields in the output panel.
    # Without a click handler the button has no effect.
    submit.click(generate, inputs=[instruction], outputs=[output])

# Enable the request queue before launching the app.
demo.queue(concurrency_count=16).launch(debug=True)