timaaos2 committed
Commit e217c21
1 Parent(s): 5bc994d

Update app.py

Files changed (1)
  1. app.py +52 -2
app.py CHANGED
@@ -1,3 +1,53 @@
-import gradio as gr
+import gradio
+from gpt import get_model
 
-gr.Interface.load("models/gpt2").launch()
+
+model_small, tokenizer_small = get_model("gpt2")
+model_large, tokenizer_large = get_model("gpt2-large")
+
+def predict(inp, model_type):
+    if model_type == "gpt2-large":
+        model, tokenizer = model_large, tokenizer_large
+        input_ids = tokenizer.encode(inp, return_tensors='tf')
+        beam_output = model.generate(input_ids, max_length=45, num_beams=5,
+                                     no_repeat_ngram_size=2,
+                                     early_stopping=True)
+        output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
+                                  clean_up_tokenization_spaces=True)
+    else:
+        model, tokenizer = model_small, tokenizer_small
+        input_ids = tokenizer.encode(inp, return_tensors='tf')
+        beam_output = model.generate(input_ids, max_length=60, num_beams=5,
+                                     no_repeat_ngram_size=2, early_stopping=True)
+        output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
+                                  clean_up_tokenization_spaces=True)
+    if output.count(".") >= 2:
+        output = ".".join(output.split(".")[:-1]) + "."
+    return output
+
+
+
+INPUTS = [gradio.inputs.Textbox(lines=2, label="Input Text"),
+          gradio.inputs.Radio(choices=["gpt2-small", "gpt2-large"],
+                              label="Model Size")]
+OUTPUTS = gradio.outputs.Textbox()
+examples = [
+    ["The toughest thing about software engineering is", "gpt2-large"],
+    ["Is this the real life? Is this just fantasy?", "gpt2-small"]
+]
+INTERFACE = gradio.Interface(fn=predict, inputs=INPUTS, outputs=OUTPUTS, title="GPT-2",
+                             description="GPT-2 is a large transformer-based language "
+                                         "model with 1.5 billion parameters, trained on "
+                                         "a dataset of 8 million web pages. GPT-2 is "
+                                         "trained with a simple objective: predict the "
+                                         "next word, given all of the previous words "
+                                         "within some text. You can configure small vs "
+                                         "large below: the large model takes longer to "
+                                         "run ("
+                                         "55s vs 30s) "
+                                         "but generates better text.",
+                             thumbnail="https://github.com/gradio-app/gpt-2/raw/master/screenshots/interface.png?raw=true",
+                             examples=examples,
+                             capture_session=False)
+
+INTERFACE.launch(inbrowser=True)
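
Note: the new app.py imports get_model from a local gpt module that is not part of this commit. A minimal sketch of what such a helper could look like, assuming the TensorFlow GPT-2 classes from the transformers library (consistent with the return_tensors='tf' calls above); only the module and function names come from the import, the rest is an assumption:

# gpt.py -- hypothetical sketch; this module is not included in the commit.
# Assumes the TensorFlow GPT-2 classes from Hugging Face transformers,
# matching the return_tensors='tf' usage in app.py.
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

def get_model(name):
    # name is a checkpoint id such as "gpt2" or "gpt2-large"
    tokenizer = GPT2Tokenizer.from_pretrained(name)
    model = TFGPT2LMHeadModel.from_pretrained(name)
    return model, tokenizer

With a helper along these lines, model.generate(...) and tokenizer.decode(...) behave as used in predict above.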