timaaos2 commited on
Commit
6b82b4f
1 Parent(s): 8f13478

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -5,7 +5,7 @@ from gpt import get_model
5
  model_small, tokenizer_small = get_model("gpt2")
6
 
7
  def predict(inp, model_type):
8
- if model_type == "gpt2":
9
  model, tokenizer = model_small, tokenizer_small
10
  input_ids = tokenizer.encode("user:"+inp+"\nai:", return_tensors='tf')
11
  beam_output = model.generate(input_ids, max_length=180, num_beams=5,
5
  model_small, tokenizer_small = get_model("gpt2")
6
 
7
  def predict(inp, model_type):
8
+ if model_type == "gpt2-small":
9
  model, tokenizer = model_small, tokenizer_small
10
  input_ids = tokenizer.encode("user:"+inp+"\nai:", return_tensors='tf')
11
  beam_output = model.generate(input_ids, max_length=180, num_beams=5,