Vaishakhh committed on
Commit
9d66f2e
1 Parent(s): c93d2b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -49,9 +49,9 @@ model_pegasus = PegasusForConditionalGeneration.from_pretrained(model_name).to(t
49
 
50
def get_max_str(lst):
    """Return the element of *lst* with the greatest len().

    Ties go to the earliest such element, and an empty *lst* raises
    ValueError — both mirroring the behavior of built-in max(key=len).
    """
    longest = None
    for candidate in lst:
        if longest is None or len(candidate) > len(longest):
            longest = candidate
    if longest is None:
        raise ValueError("max() arg is an empty sequence")
    return longest
52
- def get_response(input_text,num_return_sequences=10,num_beams=10):
53
  batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60, return_tensors='pt').to(torch_device)
54
- translated = model_pegasus.generate(**batch,max_length=60,num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5)
55
  #num_beam_groups=num_beams, diversity_penalty=0.5
56
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
57
  try:
@@ -77,7 +77,7 @@ def get_fun(txt):
77
 
78
  txt_paraphrase=''
79
  for phrase in tokens:
80
- tmp=get_response(phrase,num_return_sequences=15,num_beams=15)
81
  txt_paraphrase=txt_paraphrase+' '+tmp
82
  return txt_paraphrase
83
 
 
49
 
50
def get_max_str(lst):
    """Pick the longest string from *lst*.

    First element wins on a length tie; an empty sequence raises
    ValueError, exactly as max(lst, key=len) would.
    """
    items = iter(lst)
    try:
        best = next(items)
    except StopIteration:
        raise ValueError("max() arg is an empty sequence") from None
    for item in items:
        if len(item) > len(best):
            best = item
    return best
52
+ def get_response(input_text):
53
  batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60, return_tensors='pt').to(torch_device)
54
+ translated = model_pegasus.generate(**batch,max_length=60,num_beams=15, num_return_sequences=15, temperature=1.5)
55
  #num_beam_groups=num_beams, diversity_penalty=0.5
56
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
57
  try:
 
77
 
78
  txt_paraphrase=''
79
  for phrase in tokens:
80
+ tmp=get_response(phrase)
81
  txt_paraphrase=txt_paraphrase+' '+tmp
82
  return txt_paraphrase
83