Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,53 @@
|
|
1 |
-
import gradio
|
|
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio
|
2 |
+
from gpt import get_model
|
3 |
|
4 |
+
|
5 |
+
# Load both model sizes once at import time so each request only pays for
# inference, not model construction. get_model is assumed to return a
# (model, tokenizer) pair of TF Transformers objects (the code below calls
# generate() and passes return_tensors='tf') — TODO confirm against gpt.py.
model_small, tokenizer_small = get_model("gpt2")
model_large, tokenizer_large = get_model("gpt2-large")
|
7 |
+
|
8 |
+
def predict(inp, model_type):
    """Generate a text continuation of ``inp`` with the selected GPT-2 model.

    Args:
        inp: Prompt text to continue.
        model_type: ``"gpt2-large"`` selects the large model (shorter
            ``max_length=45`` since it is slower); any other value falls
            back to the small model with ``max_length=60``.

    Returns:
        The beam-searched continuation. When the text contains at least two
        periods, the trailing (possibly incomplete) sentence is dropped so
        the output ends on a full stop.
    """
    # The original branches duplicated the whole encode/generate/decode
    # pipeline; only the model pair and max_length actually differ.
    if model_type == "gpt2-large":
        model, tokenizer, max_length = model_large, tokenizer_large, 45
    else:
        model, tokenizer, max_length = model_small, tokenizer_small, 60
    input_ids = tokenizer.encode(inp, return_tensors='tf')
    beam_output = model.generate(input_ids, max_length=max_length,
                                 num_beams=5, no_repeat_ngram_size=2,
                                 early_stopping=True)
    output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
                              clean_up_tokenization_spaces=True)
    # Trim a trailing partial sentence so the demo output ends cleanly.
    if output.count(".") >= 2:
        output = ".".join(output.split(".")[:-1]) + "."
    return output
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
# --- Gradio UI wiring -------------------------------------------------------
# A prompt textbox plus a model-size selector feed predict(); the generated
# continuation comes back in a plain textbox.
_DESCRIPTION = (
    "GPT-2 is a large transformer-based language model with 1.5 billion "
    "parameters, trained on a dataset of 8 million web pages. GPT-2 is "
    "trained with a simple objective: predict the next word, given all "
    "of the previous words within some text. You can configure small vs "
    "large below: the large model takes longer to run (55s vs 30s) but "
    "generates better text."
)
_THUMBNAIL = ("https://github.com/gradio-app/gpt-2/raw/master/"
              "screenshots/interface.png?raw=true")

INPUTS = [
    gradio.inputs.Textbox(lines=2, label="Input Text"),
    gradio.inputs.Radio(choices=["gpt2-small", "gpt2-large"],
                        label="Model Size"),
]
OUTPUTS = gradio.outputs.Textbox()

# (prompt, model size) pairs shown as clickable examples in the UI.
examples = [
    ["The toughest thing about software engineering is", "gpt2-large"],
    ["Is this the real life? Is this just fantasy?", "gpt2-small"],
]

INTERFACE = gradio.Interface(
    fn=predict,
    inputs=INPUTS,
    outputs=OUTPUTS,
    title="GPT-2",
    description=_DESCRIPTION,
    thumbnail=_THUMBNAIL,
    examples=examples,
    capture_session=False,
)

INTERFACE.launch(inbrowser=True)
|