Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -14,9 +14,10 @@ prompter_model, prompter_tokenizer = load_prompter()
 def generate(plain_text):
     input_ids = prompter_tokenizer(plain_text.strip()+" Rephrase:", return_tensors="pt").input_ids
     eos_id = prompter_tokenizer.eos_token_id
-    outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=8, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
-    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    res = output_texts[0].replace(plain_text+" Rephrase:", "").strip()
+    # Use a single (greedy) beam and a single output: much faster than 8 beams with 8 return sequences when only the first sequence is kept.
+    outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
+    # Slice with [input_ids.shape[-1]:] instead of string replacement: the decoded, re-tokenised version of plain_text may not match the original character for character.
+    res = prompter_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
     return res
 
 txt = grad.Textbox(lines=1, label="Initial Text", placeholder="Input Prompt")
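
For context, here is a minimal, self-contained sketch of the prompt-stripping technique this change relies on: decoding only the newly generated tokens rather than string-replacing the prompt out of the full decode. The "gpt2" checkpoint below is a stand-in for illustration, not the Space's actual prompter model.

```python
# Minimal sketch: strip the prompt by slicing token IDs, not by string replacement.
# Assumes a Hugging Face decoder-only causal LM, whose generate() output
# includes the prompt tokens followed by the continuation.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "A photo of a cat. Rephrase:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Greedy decoding with a single (default) beam: one forward pass per new
# token, much cheaper than num_beams=8 / num_return_sequences=8.
outputs = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=75,
    pad_token_id=tokenizer.eos_token_id,
)

# outputs[0] holds prompt tokens + generated tokens, so slicing at
# input_ids.shape[-1] keeps only the continuation. This works even when
# decode(encode(prompt)) != prompt (e.g. whitespace normalisation), which
# is exactly what can break the old str.replace() approach.
continuation = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(continuation.strip())
```

One side note: length_penalty is only consulted by beam-based decoding, so the length_penalty=-1.0 left in the new single-beam generate() call should be a no-op and could be dropped.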