Update app.py
Browse files
app.py
CHANGED
@@ -8,16 +8,16 @@ model_large, tokenizer_large = get_model("gpt2-large")

Before (lines marked `-` were changed by this commit):

 8        def predict(inp, model_type):
 9            if model_type == "gpt2-large":
10                model, tokenizer = model_large, tokenizer_large
11    -           input_ids = tokenizer.encode(inp, return_tensors='tf')
12    -           beam_output = model.generate(input_ids, max_length=[value truncated in capture]
13                                             no_repeat_ngram_size=2,
14                                             early_stopping=True)
15                output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
16                                          clean_up_tokenization_spaces=True)
17            else:
18                model, tokenizer = model_small, tokenizer_small
19    -           input_ids = tokenizer.encode(inp, return_tensors='tf')
20    -           beam_output = model.generate(input_ids, max_length=[value truncated in capture]
21                                             no_repeat_ngram_size=2, early_stopping=True)
22                output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
23                                          clean_up_tokenization_spaces=True)
|
@@ -35,7 +35,7 @@ examples = [

Before (line marked `-` was changed by this commit):

35            ["The toughest thing about software engineering is", "gpt2-large"],
36            ["Is this the real life? Is this just fantasy?", "gpt2-small"]
37        ]
38    -   INTERFACE = gradio.Interface(fn=predict, inputs=INPUTS, outputs=OUTPUTS, title="GPT-2",
39                                     description="GPT-2 is a large transformer-based language "
40                                                 "model with 1.5 billion parameters, trained on "
41                                                 "a dataset of 8 million web pages. GPT-2 is "
|
|
|
After (lines marked `+` were added by this commit — prompt now wrapped as "user:.../ai:", explicit max_length of 60/180 and num_beams=5):

 8        def predict(inp, model_type):
 9            if model_type == "gpt2-large":
10                model, tokenizer = model_large, tokenizer_large
11    +           input_ids = tokenizer.encode("user:"+inp+"\nai:", return_tensors='tf')
12    +           beam_output = model.generate(input_ids, max_length=60, num_beams=5,
13                                             no_repeat_ngram_size=2,
14                                             early_stopping=True)
15                output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
16                                          clean_up_tokenization_spaces=True)
17            else:
18                model, tokenizer = model_small, tokenizer_small
19    +           input_ids = tokenizer.encode("user:"+inp+"\nai:", return_tensors='tf')
20    +           beam_output = model.generate(input_ids, max_length=180, num_beams=5,
21                                             no_repeat_ngram_size=2, early_stopping=True)
22                output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
23                                          clean_up_tokenization_spaces=True)
|
|
|
After (line marked `+` was changed by this commit — title renamed to "Chat GPT-2"):

35            ["The toughest thing about software engineering is", "gpt2-large"],
36            ["Is this the real life? Is this just fantasy?", "gpt2-small"]
37        ]
38    +   INTERFACE = gradio.Interface(fn=predict, inputs=INPUTS, outputs=OUTPUTS, title="Chat GPT-2",
39                                     description="GPT-2 is a large transformer-based language "
40                                                 "model with 1.5 billion parameters, trained on "
41                                                 "a dataset of 8 million web pages. GPT-2 is "
|