Vaishakhh committed on
Commit
e45f695
1 Parent(s): cb694c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -38,7 +38,7 @@ device= "cuda:0"
38
  adequacy_threshold = 0.90
39
  fluency_threshold = 0.90
40
  diversity_ranker="levenshtein"
41
-
42
  model_name = 'tuner007/pegasus_paraphrase'
43
  torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
44
  tokenizer = PegasusTokenizer.from_pretrained(model_name)
@@ -47,8 +47,8 @@ model_pegasus = PegasusForConditionalGeneration.from_pretrained(model_name).to(t
47
  def get_max_str(lst):
48
  return max(lst, key=len)
49
  def get_response(input_text,num_return_sequences=10,num_beams=10):
50
- batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60,return_tensors='pt').to(torch_device)
51
- translated = model_pegasus.generate(**batch,max_length=60,num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5)
52
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
53
  try:
54
  adequacy_filtered_phrases = adequacy_score.filter(input_text,tgt_text, adequacy_threshold, device)
@@ -71,7 +71,7 @@ def get_fun(txt):
71
 
72
  txt_paraphrase=''
73
  for phrase in tokens:
74
- tmp=get_response(phrase,num_return_sequences=10,num_beams=10)
75
  txt_paraphrase=txt_paraphrase+' '+tmp
76
  return txt_paraphrase
77
 
 
38
  adequacy_threshold = 0.90
39
  fluency_threshold = 0.90
40
  diversity_ranker="levenshtein"
41
+ do_diverse=True
42
  model_name = 'tuner007/pegasus_paraphrase'
43
  torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
44
  tokenizer = PegasusTokenizer.from_pretrained(model_name)
 
47
  def get_max_str(lst):
48
  return max(lst, key=len)
49
  def get_response(input_text,num_return_sequences=10,num_beams=10):
50
+ batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=30,return_tensors='pt').to(torch_device)
51
+ translated = model_pegasus.generate(**batch,max_length=30,num_beams=num_beams, num_return_sequences=num_return_sequences, num_beam_groups=num_beams, diversity_penalty=0.5, temperature=1.5)
52
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
53
  try:
54
  adequacy_filtered_phrases = adequacy_score.filter(input_text,tgt_text, adequacy_threshold, device)
 
71
 
72
  txt_paraphrase=''
73
  for phrase in tokens:
74
+ tmp=get_response(phrase,num_return_sequences=30,num_beams=30)
75
  txt_paraphrase=txt_paraphrase+' '+tmp
76
  return txt_paraphrase
77