|
import torch |
|
import gradio as gr |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
|
|
|
# Hugging Face Hub identifiers. The tokenizer and the summarization weights
# are published under different repositories, hence two separate names.
TOKENIZER_NAME = "unicamp-dl/ptt5-base-portuguese-vocab"
MODEL_NAME = "recogna-nlp/ptt5-base-summ-wikilingua"

# Summary length policy. TARGET_LENGTH is the desired summary size; a summary
# is accepted when its length falls within MARGIN of that target, and at most
# MAX_ATTEMPTS generations are tried before settling for the closest one.
TARGET_LENGTH = 256
MARGIN = 6
MIN_LENGTH, MAX_LENGTH = TARGET_LENGTH - MARGIN, TARGET_LENGTH + MARGIN
MAX_ATTEMPTS = 5

# Loaded once at import time so every request reuses the same objects.
tokenizer = T5Tokenizer.from_pretrained(TOKENIZER_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
|
|
|
|
|
def summarize_text(text):
    """Summarize *text*, aiming for a summary of about TARGET_LENGTH characters.

    The input is tokenized once (truncated to 512 tokens) and summarized with
    beam search. After every attempt the generation budget is rescaled by the
    relative error between the summary's character count and TARGET_LENGTH,
    for at most MAX_ATTEMPTS attempts.

    Args:
        text: Source text to summarize.

    Returns:
        The first summary whose character count lies within
        [MIN_LENGTH, MAX_LENGTH]. If no attempt lands in that range, the
        attempt whose length was closest to TARGET_LENGTH. An empty string
        when *text* is empty or contains only whitespace.
    """
    # Nothing to summarize: skip the model entirely instead of letting it
    # generate text from an empty prompt.
    if not text or not text.strip():
        return ""

    # Lower bound (in tokens) passed to generate(); the budget below must
    # always stay above it.
    min_tokens = 32

    # Tokenization does not depend on the attempt, so do it once up front.
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

    # generate() counts max_length in tokens while the target is expressed in
    # characters; the feedback loop below is what bridges the two units.
    token_budget = TARGET_LENGTH
    best_summary = ""
    best_distance = float("inf")

    for _ in range(MAX_ATTEMPTS):
        summary_ids = model.generate(
            **inputs,
            max_length=token_budget,
            min_length=min_tokens,
            num_beams=5,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        summary_length = len(summary)

        # Good enough: stop at the first summary inside the accepted range.
        if MIN_LENGTH <= summary_length <= MAX_LENGTH:
            return summary

        # Otherwise remember the closest miss as a fallback.
        distance = abs(TARGET_LENGTH - summary_length)
        if distance < best_distance:
            best_summary, best_distance = summary, distance

        # Shrink (or grow) the budget by the relative length error.
        error_percent = (summary_length - TARGET_LENGTH) / TARGET_LENGTH
        token_budget -= int(token_budget * error_percent)

        # A summary at least twice the target length would otherwise drive the
        # budget to zero or below; keep it strictly above the minimum length
        # so the next generate() call receives a usable range.
        token_budget = max(token_budget, min_tokens + 1)

    return best_summary
|
|
|
|
|
|
|
# Gradio front end: one multi-line text box in, one text box out, with
# summarize_text doing the work in between.
interface = gr.Interface(
    title="Resumidor de Textos com PTT5-SUMM-WIKILINGUA",
    description="Insira um texto e receba um resumo dentro do intervalo de 250 a 262 caracteres.",
    fn=summarize_text,
    inputs=gr.Textbox(
        label="Texto de Entrada",
        lines=10,
        placeholder="Digite ou cole seu texto aqui...",
    ),
    outputs=gr.Textbox(label="Resumo Gerado"),
)


if __name__ == "__main__":
    # Start the app only when run as a script; share=True asks Gradio to
    # also expose a shareable link in addition to the local server.
    interface.launch(share=True)
|
|