File size: 3,837 Bytes
3d9b66a
 
eb4b465
3d9b66a
eb4b465
 
3d9b66a
 
 
 
 
 
 
8e9c357
dac8d18
3d9b66a
 
 
7f39fb2
3d9b66a
d797f28
7f39fb2
d797f28
7f39fb2
3d9b66a
 
2b8b93a
 
 
 
3d9b66a
2636ace
3d9b66a
 
 
 
ac62d88
 
 
 
eb4b465
e1e1d6a
eb4b465
 
9b11bf4
eb4b465
 
 
 
 
 
 
 
 
 
 
 
 
ac62d88
 
 
e01d0a4
3d9b66a
8e9c357
3d9b66a
2b8b93a
 
3d9b66a
 
 
2636ace
 
8e9c357
eb4b465
3d9b66a
8e9c357
 
 
 
e01d0a4
d4533bc
8e9c357
 
e01d0a4
d4533bc
8e9c357
d4533bc
3d9b66a
 
 
 
 
7f39fb2
 
 
 
8e9c357
 
 
7f39fb2
3d9b66a
 
7f39fb2
 
3d9b66a
2636ace
3d9b66a
eb4b465
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import torch
import ecco
import requests
from transformers import AutoTokenizer
from torch.nn import functional as F

# Prompt template handed to the code model. `code_from_prompts` fills in the
# upper-case placeholders with plain `str.replace` calls:
#   CONN   -> the psycopg2 connection string
#   MIDDLE -> an optional `get_customer` helper (or nothing at all)
#   PROMPT -> the comment text typed by the user
# The template ends mid-statement, right after `SET name =`, so whatever the
# model generates next is the remainder of the SQL query.
# NOTE: the literal is not a raw string, so the `\n` and `\t` sequences on the
# last line become real newline and tab characters.
header = """
import psycopg2

conn = psycopg2.connect("CONN")
cur = conn.cursor()

MIDDLE
def rename_customer(id, newName):\n\t# PROMPT\n\tcur.execute("UPDATE customer SET name =
"""

# Display name offered in the UI -> identifier passed to `from_pretrained`
# for both the tokenizer and the ecco model. Commented-out entries are
# currently disabled.
modelPath = {
    # "GPT2-Medium": "gpt2-medium",
    "CodeParrot-mini": "codeparrot/codeparrot-small",
    # "CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
    # "GPT-Neo-1.3B": "EleutherAI/gpt-neo-1.3B",
    "CodeParrot": "codeparrot/codeparrot",
    # "CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
}

# Load every configured model once, at import time, keyed by its display name;
# `code_from_prompts` looks the model up here instead of loading it per request.
preloadModels = {
    display_name: ecco.from_pretrained(model_id)
    for display_name, model_id in modelPath.items()
}

def generation(tokenizer, model, content):
    """Complete ``content`` and score how likely the model is to concatenate strings.

    Two results are derived from the same prompt:

    * the text obtained by joining the ``tokens`` of a non-sampled,
      six-token generation, and
    * the summed probability that the model continues with one of the two
      token sequences derived from ``= '" +`` and ``= " +`` (closing the SQL
      string literal and appending to it with ``+``).

    Args:
        tokenizer: Object providing ``encode``, ``decode`` and ``__call__``
            (the latter returning a mapping with an ``'input_ids'`` entry).
        model: Object whose ``generate(text, generate=n, do_sample=False)``
            result exposes ``tokens``, ``hidden_states``, ``lm_head`` and ``to``.
        content: Source-code prompt to be continued.

    Returns:
        A two-item list: the joined tokens of the six-token generation and a
        message of the form ``"<percent>% chance of risky concatenation"``.
    """
    # Continuations that indicate string concatenation. The first encoded
    # token of each pattern is dropped -- presumably the leading `=`, which
    # the prompt already ends with (TODO confirm against the tokenizer).
    risky_continuations = [
        tokenizer.encode('= \'" +')[1:],
        tokenizer.encode('= " +')[1:],
    ]

    full_output = model.generate(content, generate=6, do_sample=False)

    def sequence_probability(code, position, token_ids):
        """Return the probability that ``code`` is followed by ``token_ids``.

        ``position`` is the number of tokens in ``code``; the product of the
        per-token probabilities is built up recursively, one token per call.
        """
        step = model.generate(code, generate=1, do_sample=False)
        # Last entry of `hidden_states` (presumably the final layer), read at
        # the index of the last token of `code`.
        hidden_state = step.hidden_states[-1][position - 1]
        logits = step.lm_head(step.to(hidden_state))
        token_prob = F.softmax(logits, dim=-1)[token_ids[0]]

        if len(token_ids) == 1:
            return token_prob

        # Append the token just scored and score the rest of the sequence.
        # NOTE(review): `position + 1` assumes the decoded text re-tokenizes
        # to exactly one additional token -- confirm for the models in use.
        extended_code = code + tokenizer.decode(token_ids[0])
        return token_prob * sequence_probability(
            extended_code, position + 1, token_ids[1:]
        )

    # The prompt length does not change between patterns, so tokenize once.
    prompt_length = len(tokenizer(content)['input_ids'])

    prob = 0
    for token_ids in risky_continuations:
        prob += sequence_probability(content, prompt_length, token_ids)

    return [
        "".join(full_output.tokens),
        str(prob.item() * 100) + '% chance of risky concatenation',
    ]

def code_from_prompts(prompt, model, type_hints, pre_content):
    """Build the code prompt from the UI inputs and run it through the model.

    Args:
        prompt: Comment text inserted at the ``PROMPT`` placeholder of the
            template.
        model: Key of ``modelPath`` / ``preloadModels`` selecting the model.
        type_hints: When truthy, the ``rename_customer`` signature in the
            template is annotated with type hints.
        pre_content: UI choice describing a helper the user has "already
            written": ``'None'``, a label containing ``'Concatenation'`` or a
            label containing ``'composition'``. ``None`` (no option selected)
            is treated like ``'None'``.

    Returns:
        The two-item list returned by ``generation``.
    """
    tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
    language_model = preloadModels[model]

    code = header.strip().replace('CONN', "dbname='store'")

    # Annotate the template's signature before the helper is inserted, so the
    # helper's own `get_customer(id)` / `str(id)` text is left untouched.
    if type_hints:
        code = code.replace('id,', 'id: int,')
        code = code.replace('id)', 'id: int)')
        code = code.replace('newName)', 'newName: str) -> None')

    concatenating_helper = (
        "def get_customer(id):\n"
        "\tcur.execute('SELECT * FROM customers WHERE id = ' + str(id))\n"
        "\treturn cur.fetchall()\n"
    )
    parameterized_helper = (
        "def get_customer(id):\n"
        "\tcur.execute('SELECT * FROM customers WHERE id = %s', str(id))\n"
        "\treturn cur.fetchall()\n"
    )

    if pre_content is None or pre_content == 'None':
        # No helper requested: drop the placeholder line entirely.
        code = code.replace('MIDDLE\n', '')
    elif 'Concatenation' in pre_content:
        code = code.replace('MIDDLE', concatenating_helper)
    elif 'composition' in pre_content:
        code = code.replace('MIDDLE', parameterized_helper)

    # Insert the user's text last so that none of the substitutions above can
    # rewrite it (e.g. a comment that itself contains "id," or "MIDDLE").
    code = code.replace('PROMPT', prompt)

    return generation(tokenizer, language_model, code)

# Input widgets, listed in the positional order of `code_from_prompts`'
# parameters: prompt, model, type_hints, pre_content.
ui_inputs = [
    gr.components.Textbox(label="Insert comment"),
    gr.components.Radio(list(modelPath), label="Code Model"),
    gr.components.Checkbox(label="Include type hints"),
    gr.components.Radio(
        [
            "None",
            "Proper composition: Include function 'WHERE id = %s'",
            "Concatenation: Include a function with 'WHERE id = ' + id",
        ],
        label="Has user already written a function?",
    ),
]

# Output widgets, matching the two-item list returned by `code_from_prompts`.
ui_outputs = [
    gr.components.Textbox(label="Most probable code"),
    gr.components.Textbox(label="Probability of concat"),
]

iface = gr.Interface(
    fn=code_from_prompts,
    inputs=ui_inputs,
    outputs=ui_outputs,
    description="Prompt the code model to write a SQL query with string concatenation.",
)
iface.launch()