timaaos2 committed on
Commit
909674f
1 Parent(s): a60723a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -8,16 +8,16 @@ model_large, tokenizer_large = get_model("gpt2-large")
8
  def predict(inp, model_type):
9
  if model_type == "gpt2-large":
10
  model, tokenizer = model_large, tokenizer_large
11
- input_ids = tokenizer.encode(inp, return_tensors='tf')
12
- beam_output = model.generate(input_ids, max_length=45, num_beams=5,
13
  no_repeat_ngram_size=2,
14
  early_stopping=True)
15
  output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
16
  clean_up_tokenization_spaces=True)
17
  else:
18
  model, tokenizer = model_small, tokenizer_small
19
- input_ids = tokenizer.encode(inp, return_tensors='tf')
20
- beam_output = model.generate(input_ids, max_length=60, num_beams=5,
21
  no_repeat_ngram_size=2, early_stopping=True)
22
  output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
23
  clean_up_tokenization_spaces=True)
@@ -35,7 +35,7 @@ examples = [
35
  ["The toughest thing about software engineering is", "gpt2-large"],
36
  ["Is this the real life? Is this just fantasy?", "gpt2-small"]
37
  ]
38
- INTERFACE = gradio.Interface(fn=predict, inputs=INPUTS, outputs=OUTPUTS, title="GPT-2",
39
  description="GPT-2 is a large transformer-based language "
40
  "model with 1.5 billion parameters, trained on "
41
  "a dataset of 8 million web pages. GPT-2 is "
 
8
  def predict(inp, model_type):
9
  if model_type == "gpt2-large":
10
  model, tokenizer = model_large, tokenizer_large
11
+ input_ids = tokenizer.encode("user:"+inp+"\nai:", return_tensors='tf')
12
+ beam_output = model.generate(input_ids, max_length=60, num_beams=5,
13
  no_repeat_ngram_size=2,
14
  early_stopping=True)
15
  output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
16
  clean_up_tokenization_spaces=True)
17
  else:
18
  model, tokenizer = model_small, tokenizer_small
19
+ input_ids = tokenizer.encode("user:"+inp+"\nai:", return_tensors='tf')
20
+ beam_output = model.generate(input_ids, max_length=180, num_beams=5,
21
  no_repeat_ngram_size=2, early_stopping=True)
22
  output = tokenizer.decode(beam_output[0], skip_special_tokens=True,
23
  clean_up_tokenization_spaces=True)
 
35
  ["The toughest thing about software engineering is", "gpt2-large"],
36
  ["Is this the real life? Is this just fantasy?", "gpt2-small"]
37
  ]
38
+ INTERFACE = gradio.Interface(fn=predict, inputs=INPUTS, outputs=OUTPUTS, title="Chat GPT-2",
39
  description="GPT-2 is a large transformer-based language "
40
  "model with 1.5 billion parameters, trained on "
41
  "a dataset of 8 million web pages. GPT-2 is "