import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from transformers import pipeline
import os
import torch
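# Gradio demo app: plain left-to-right code completion plus fill-in-the-middle
# (FIM) infilling, selected by the presence of the <FILL-HERE> token in the prompt.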

description = """# <p style="text-align: center; color: white;"> 🎅 <span style='color: #ff75b3;'>SantaCoder:</span> Code Generation </p>
<span style='color: white;'>This is a demo to generate code with <a href="https://huggingface.co/bigcode/santacoder" style="color: #ff75b3;">SantaCoder</a>,
a 1.1B parameter model for code generation in Python, Java & JavaScript. The model can also do infilling: just specify where you would like the model to complete code
with the <span style='color: #ff75b3;'>&lt;FILL-HERE&gt;</span> token.</span>"""

# The Hub token comes from the Space's HUB_TOKEN secret (needed to access the
# model repositories); the demo assumes a single CUDA device.
token = os.environ["HUB_TOKEN"]
device = "cuda:0"


FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
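# SantaCoder's FIM training format: prompts are laid out as
# <fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle>, and the model generates
# the missing middle span, terminated by <|endoftext|>.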

GENERATION_TITLE = "<p style='font-size: 16px; color: white;'>Generated code:</p>"

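# A second tokenizer configured for infilling: left padding keeps every prompt in
# a batch ending at <fim-middle>, exactly where generation should start.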
tokenizer_fim = AutoTokenizer.from_pretrained("bigcode/santacoder", use_auth_token=token, padding_side="left")

tokenizer_fim.add_special_tokens({
  "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
  "pad_token": EOD,
})

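# The demo checkpoint ships custom modeling code (hence trust_remote_code=True);
# the pipeline handles the plain completion path, generate() the FIM path.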
tokenizer = AutoTokenizer.from_pretrained("bigcode/christmas-models", use_auth_token=token)
model = AutoModelForCausalLM.from_pretrained("bigcode/christmas-models", trust_remote_code=True, use_auth_token=token).to(device)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

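# Wrap the prompt and the model's completion in differently colored <span>s and
# return them as an HTML code block for the output panel.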
def post_processing(prompt, completion):
    completion = "<span style='color: #ff75b3;'>" + completion + "</span>"
    prompt = "<span style='color: #727cd6;'>" + prompt + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prompt}{completion}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def post_processing_fim(prefix, middle, suffix):
    prefix = "<span style='color: #727cd6;'>" + prefix + "</span>"
    middle = "<span style='color: #ff75b3;'>" + middle + "</span>"
    suffix = "<span style='color: #727cd6;'>" + suffix + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prefix}{middle}{suffix}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def fim_generation(prompt, max_new_tokens, temperature):
    # Split once on the sentinel: everything before it is the prefix, everything
    # after it (including any further occurrences) is the suffix.
    prefix, suffix = prompt.split("<FILL-HERE>", 1)
    [middle] = infill((prefix, suffix), max_new_tokens, temperature)
    return post_processing_fim(prefix, middle, suffix)

def extract_fim_part(s: str):
    # Return the infilled middle: the text between the FIM_MIDDLE sentinel and the
    # first EOD token after it (or the end of the string if no EOD was generated).
    start = s.find(FIM_MIDDLE) + len(FIM_MIDDLE)
    stop = s.find(EOD, start)
    if stop == -1:
        stop = len(s)
    return s[start:stop]

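# A hedged, illustrative example of the FIM round trip: for the pair
# ("def add(a, b):\n    return ", "\n") the assembled prompt is
#   <fim-prefix>def add(a, b):\n    return <fim-suffix>\n<fim-middle>
# and the model is expected to emit the middle, e.g. "a + b", followed by EOD.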
def infill(prefix_suffix_tuples, max_new_tokens, temperature):
    if isinstance(prefix_suffix_tuples, tuple):
        prefix_suffix_tuples = [prefix_suffix_tuples]
        
    prompts = [f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" for prefix, suffix in prefix_suffix_tuples]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer_fim(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer_fim.pad_token_id  # the FIM tokenizer's pad token (EOD), set above
        )
    # WARNING: cannot use skip_special_tokens, because it blows away the FIM special tokens.
    return [        
        extract_fim_part(tokenizer_fim.decode(tensor, skip_special_tokens=False)) for tensor in outputs
    ]


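# Entry point wired to the Run button: route to FIM infilling when the prompt
# contains <FILL-HERE>, otherwise fall back to ordinary completion.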
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42):
    set_seed(seed)

    if "<FILL-HERE>" in prompt:
        return fim_generation(prompt, max_new_tokens, temperature=temperature)
    else:
        completion = pipe(prompt, do_sample=True, top_p=0.95, temperature=temperature, max_new_tokens=max_new_tokens)[0]['generated_text']
        completion = completion[len(prompt):]
        return post_processing(prompt, completion)


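# Dark-themed UI: a code editor for input, sliders for generation settings, and
# an HTML panel for the highlighted output.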
demo = gr.Blocks(
    css=".gradio-container {background-color: #20233fff; color:white}"
)
with demo:
    with gr.Row():
        _, column_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
        with column_2:
            gr.Markdown(value=description)
            code = gr.Code(lines=5, language="python", label="Input code", value="def all_odd_elements(sequence):\n    \"\"\"Returns every odd element of the sequence.\"\"\"")
            
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens= gr.Slider(
                    minimum=8,
                    maximum=1024,
                    step=1,
                    value=48,
                    label="Number of tokens to generate",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    value=42,
                    label="Random seed to use for the generation",
                )
            run = gr.Button()
            output = gr.HTML(label="Generated code")

    run.click(code_generation, [code, max_new_tokens, temperature, seed], output, api_name="predict")
    gr.HTML(label="Contact", value="<img src='https://huggingface.co/datasets/bigcode/admin/resolve/main/bigcode_contact.png' alt='contact' style='display: block; margin: auto; max-width: 800px;'>")

demo.launch()