Update app.py
Browse files
app.py
CHANGED
@@ -38,7 +38,7 @@ device= "cuda:0"
|
|
38 |
adequacy_threshold = 0.90
|
39 |
fluency_threshold = 0.90
|
40 |
diversity_ranker="levenshtein"
|
41 |
-
|
42 |
model_name = 'tuner007/pegasus_paraphrase'
|
43 |
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
44 |
tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
@@ -47,8 +47,8 @@ model_pegasus = PegasusForConditionalGeneration.from_pretrained(model_name).to(t
|
|
47 |
def get_max_str(lst):
|
48 |
return max(lst, key=len)
|
49 |
def get_response(input_text,num_return_sequences=10,num_beams=10):
|
50 |
-
batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=
|
51 |
-
translated = model_pegasus.generate(**batch,max_length=
|
52 |
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
53 |
try:
|
54 |
adequacy_filtered_phrases = adequacy_score.filter(input_text,tgt_text, adequacy_threshold, device)
|
@@ -71,7 +71,7 @@ def get_fun(txt):
|
|
71 |
|
72 |
txt_paraphrase=''
|
73 |
for phrase in tokens:
|
74 |
-
tmp=get_response(phrase,num_return_sequences=
|
75 |
txt_paraphrase=txt_paraphrase+' '+tmp
|
76 |
return txt_paraphrase
|
77 |
|
|
|
# --- paraphrase-filtering configuration ---
adequacy_threshold = 0.90         # minimum adequacy score for a candidate to survive filtering
fluency_threshold = 0.90          # minimum fluency score for a candidate to survive filtering
diversity_ranker = "levenshtein"  # ranking metric used to diversify candidates
do_diverse = True                 # enable diverse beam search downstream

# --- Pegasus paraphrase model setup ---
# NOTE(review): model weights are fetched from the Hugging Face hub on first run.
model_name = 'tuner007/pegasus_paraphrase'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer GPU when one is present
tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
|
|
def get_max_str(lst):
    """Return the longest element of *lst*.

    Ties keep the earliest candidate (built-in ``max`` semantics);
    raises ``ValueError`` when *lst* is empty.
    """
    longest = max(lst, key=lambda text: len(text))
    return longest
|
49 |
def get_response(input_text,num_return_sequences=10,num_beams=10):
|
50 |
+
batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=30,return_tensors='pt').to(torch_device)
|
51 |
+
translated = model_pegasus.generate(**batch,max_length=30,num_beams=num_beams, num_return_sequences=num_return_sequences, num_beam_groups=num_beams, diversity_penalty=0.5, temperature=1.5)
|
52 |
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
53 |
try:
|
54 |
adequacy_filtered_phrases = adequacy_score.filter(input_text,tgt_text, adequacy_threshold, device)
|
|
|
71 |
|
72 |
txt_paraphrase=''
|
73 |
for phrase in tokens:
|
74 |
+
tmp=get_response(phrase,num_return_sequences=30,num_beams=30)
|
75 |
txt_paraphrase=txt_paraphrase+' '+tmp
|
76 |
return txt_paraphrase
|
77 |
|