File size: 1,824 Bytes
eb96a65
93d5d8a
eb96a65
 
93d5d8a
 
 
eb96a65
 
38dd04a
 
eb96a65
38dd04a
eb96a65
 
 
 
 
 
38dd04a
eb96a65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import functools
from urllib.request import urlretrieve

import gradio as gr
from transformers import AutoTokenizer, AutoModelWithLMHead

# Fetch the GitHub page for the GPT-python model folder.
# NOTE(review): urlretrieve() called without a `filename` saves the response to
# a temporary file and returns that path; the return value is discarded here.
# The URL also appears to be a GitHub "tree" (HTML listing) page rather than
# raw model files. So this call does NOT populate the local "GPT-python"
# directory that the model-loading code reads from — the model must already be
# present on disk. It still performs a network request at import time, and
# urlretrieve raises (e.g. URLError) on failure, which would stop the app from
# starting. Confirm the intended way of obtaining the weights.
urlretrieve("https://github.com/equ1/generative_python_transformer/tree/main/GPT-python")

# inference function
def inference(inp):
    tokenizer = AutoTokenizer.from_pretrained("GPT-python").to("cuda")
    model = AutoModelWithLMHead.from_pretrained("GPT-python").to("cuda")

    input_ids = tokenizer.encode(inp, return_tensors="pt").to("cuda")
    beam_output = model.generate(input_ids, 
                               max_length=512,
                               num_beams=10,
                               temperature=0.7,
                               no_repeat_ngram_size=5,
                               num_return_sequences=1,
                               ).to("cuda")
  
    output = []
    for beam in beam_output:
        out = tokenizer.decode(beam)
        fout = out.replace("<N>", "\n")
        output.append(fout)

    return '\n'.join(output)

# Description shown beneath the title in the Gradio UI. Gradio renders it as
# Markdown, where lines indented by four or more spaces after a blank line
# become code blocks, so the text is kept flush-left and each paragraph is
# separated by a single blank line.
desc = (
    "Enter some Python code and click submit to see the model's autocompletion.\n\n"
    'Best results have been observed with the prompt of "import".\n\n'
    "Please note that outputs are reflective of a model trained on a measly "
    "40 MB of text data for a single epoch of ~16 GPU hours. Given more data "
    "and training time, the autocompletion should be much stronger.\n\n"
    "Computation will take some time."
)

# Wire `inference` into a simple web UI (multi-line prompt box in, plain text
# box out) and start serving it.
input_box = gr.inputs.Textbox(lines=5, label="Input Text")
output_box = gr.outputs.Textbox()
gr.Interface(
    fn=inference,
    inputs=input_box,
    outputs=output_box,
    title="Generative Python Transformer",
    description=desc,
).launch()