File size: 2,058 Bytes
e217c21
 
5bc994d
e217c21
 
 
 
6b82b4f
e217c21
909674f
 
e217c21
 
 
 
 
 
 
 
 
 
8f13478
e217c21
 
 
8f13478
e217c21
 
909674f
e217c21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gradio
from gpt import get_model


# Load the "gpt2" checkpoint once at import time; predict() reuses these
# module-level objects for every request instead of reloading per call.
# get_model comes from the project-local `gpt` module — presumably it returns a
# TensorFlow model (predict() encodes with return_tensors='tf') and its
# tokenizer; confirm against gpt.py.
model_small, tokenizer_small = get_model("gpt2")

def predict(inp, model_type):
    """Generate a GPT-2 reply to ``inp`` for the Gradio interface.

    The user text is wrapped in a minimal ``"user:<text>\\nai:"`` prompt and
    decoded with beam search. If the decoded text contains two or more
    periods, it is trimmed back to end at its last period.

    Args:
        inp: Text from the "Input Text" box.
        model_type: Value of the "Model Size" radio. Only ``"gpt2-small"``
            is supported.

    Returns:
        The decoded (and possibly trimmed) model output string.

    Raises:
        ValueError: If ``model_type`` is not a supported choice (including
            ``None`` when nothing was selected).
    """
    if model_type != "gpt2-small":
        # Previously an unknown or missing choice fell through and crashed
        # with UnboundLocalError on ``output``; fail with a clear message.
        raise ValueError(f"Unsupported model type: {model_type!r}")
    model, tokenizer = model_small, tokenizer_small
    # Presumably the chat-style template steers the model to answer after "ai:".
    input_ids = tokenizer.encode("user:" + inp + "\nai:", return_tensors='tf')
    beam_output = model.generate(input_ids, max_length=180, num_beams=5,
                                 no_repeat_ngram_size=2, early_stopping=True)
    output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
                              clean_up_tokenization_spaces=True)
    # Generation can be cut off at max_length mid-sentence; when there are at
    # least two periods, drop whatever trails the last one.
    if output.count(".") >= 2:
        head, period, _ = output.rpartition(".")
        output = head + period
    return output

 

# Gradio UI definition (legacy gradio.inputs / gradio.outputs API, as used
# throughout this file). The input order must match predict(inp, model_type).
INPUTS = [gradio.inputs.Textbox(lines=2, label="Input Text"),
          # Preselect the only available model so submitting without touching
          # the radio does not pass an empty choice to predict().
          gradio.inputs.Radio(choices=["gpt2-small"],
                              default="gpt2-small",
                              label="Model Size")]
OUTPUTS = gradio.outputs.Textbox()
# Each example row supplies one value per entry in INPUTS.
examples = [
    ["The toughest thing about software engineering is", "gpt2-small"],
    ["Is this the real life? Is this just fantasy?", "gpt2-small"]
]
# The description previously offered a "small vs large" choice (with timings)
# that this app does not provide; it now describes only what is offered.
INTERFACE = gradio.Interface(fn=predict, inputs=INPUTS, outputs=OUTPUTS,
                             title="Chat GPT-2",
                             description="GPT-2 is a transformer-based "
                                         "language model trained on a "
                                         "dataset of 8 million web pages "
                                         "with a simple objective: predict "
                                         "the next word, given all of the "
                                         "previous words within some text. "
                                         "This demo runs the small GPT-2 "
                                         "model (the largest GPT-2 has 1.5 "
                                         "billion parameters) prompted as a "
                                         "simple user/ai chat.",
                             thumbnail="https://github.com/gradio-app/gpt-2/raw/master/screenshots/interface.png?raw=true",
                             examples=examples,
                             capture_session=False)

INTERFACE.launch(inbrowser=True)