Update app.py
Browse files
app.py
CHANGED
@@ -49,9 +49,9 @@ model_pegasus = PegasusForConditionalGeneration.from_pretrained(model_name).to(t
|
|
49 |
|
50 |
def get_max_str(lst):
    """Return the longest string in *lst*.

    Ties go to the earliest element, matching the semantics of the
    built-in ``max``. Raises ``ValueError`` when *lst* is empty.
    """
    return max(lst, key=lambda item: len(item))
|
52 |
-
def get_response(input_text
|
53 |
batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60, return_tensors='pt').to(torch_device)
|
54 |
-
translated = model_pegasus.generate(**batch,max_length=60,num_beams=
|
55 |
#num_beam_groups=num_beams, diversity_penalty=0.5
|
56 |
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
57 |
try:
|
@@ -77,7 +77,7 @@ def get_fun(txt):
|
|
77 |
|
78 |
txt_paraphrase=''
|
79 |
for phrase in tokens:
|
80 |
-
tmp=get_response(phrase
|
81 |
txt_paraphrase=txt_paraphrase+' '+tmp
|
82 |
return txt_paraphrase
|
83 |
|
|
|
49 |
|
50 |
def get_max_str(lst):
    """Pick the longest entry out of *lst*.

    When several entries share the maximum length, the first one wins
    (same tie-breaking as ``max``). An empty *lst* raises ``ValueError``.
    """
    return max(lst, key=lambda candidate: len(candidate))
|
52 |
+
def get_response(input_text):
|
53 |
batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60, return_tensors='pt').to(torch_device)
|
54 |
+
translated = model_pegasus.generate(**batch,max_length=60,num_beams=15, num_return_sequences=15, temperature=1.5)
|
55 |
#num_beam_groups=num_beams, diversity_penalty=0.5
|
56 |
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
57 |
try:
|
|
|
77 |
|
78 |
txt_paraphrase=''
|
79 |
for phrase in tokens:
|
80 |
+
tmp=get_response(phrase)
|
81 |
txt_paraphrase=txt_paraphrase+' '+tmp
|
82 |
return txt_paraphrase
|
83 |
|