fafiz committed on
Commit
81ca650
1 Parent(s): 9a604b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -6,8 +6,8 @@ torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
  tokenizer = AutoTokenizer.from_pretrained("prithivida/parrot_paraphraser_on_T5")
7
  model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/parrot_paraphraser_on_T5")
8
  def get_response(input_text,num_return_sequences):
9
- batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=False,padding='longest',max_length=100, return_tensors="pt").to(torch_device)
10
- translated = model.generate(**batch,max_length=100,num_beams=10, num_return_sequences=num_return_sequences, temperature=1.9)
11
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
12
  return tgt_text
13
 
 
6
  tokenizer = AutoTokenizer.from_pretrained("prithivida/parrot_paraphraser_on_T5")
7
  model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/parrot_paraphraser_on_T5")
8
  def get_response(input_text,num_return_sequences):
9
+ batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=100, return_tensors="pt").to(torch_device)
10
+ translated = model.generate(**batch,max_length=100,num_beams=10, num_return_sequences=num_return_sequences, temperature=0.9)
11
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
12
  return tgt_text
13