# Hugging Face Spaces status banner ("Spaces: Sleeping") captured by the
# page scrape — kept only as a comment so this file parses as Python.
import os

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face model repo of the fine-tuned BART checkpoint, and the access
# token (read from the Space's TOKEN secret) used to download it.
MODEL = "xTorch8/fine-tuned-bart"
TOKEN = os.getenv("TOKEN")
# Upper bound on input tokens passed to the tokenizer/model per call.
MAX_TOKENS = 1024

# Load the model and tokenizer once at startup so requests reuse them.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, token=TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=TOKEN)
def summarize_text(text):
    """Summarize *text* with the fine-tuned BART model.

    Long inputs are split into overlapping character chunks, each chunk is
    summarized independently, and if the concatenation of those chunk
    summaries is still long it is summarized once more into a final result.

    Args:
        text: Raw input text of any length.

    Returns:
        The summary string, or the error message as a string if
        summarization fails (Gradio's text output expects a string).
    """
    try:
        # Rough heuristic of ~4 characters per token, so one chunk of
        # MAX_TOKENS * 4 characters approximately fills the input window.
        chunk_size = MAX_TOKENS * 4
        overlap = chunk_size // 4  # 25% overlap preserves context at seams
        step = chunk_size - overlap
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]

        summaries = []
        for chunk in chunks:
            inputs = tokenizer(chunk, return_tensors="pt", truncation=True,
                               max_length=MAX_TOKENS, padding=True)
            with torch.no_grad():
                summary_ids = model.generate(
                    **inputs,
                    # NOTE(review): 1500 exceeds the 1024 decoder positions
                    # of standard BART; early_stopping usually ends
                    # generation first, but confirm the checkpoint allows it.
                    max_length=1500,
                    length_penalty=2.0,
                    num_beams=4,
                    early_stopping=True,
                )
            summaries.append(
                tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            )

        final_text = " ".join(summaries)

        # Compress once more when the merged summary is still long.  (The
        # comparison is characters vs. MAX_TOKENS tokens — a loose check.)
        if len(final_text) > MAX_TOKENS:
            inputs = tokenizer(final_text, return_tensors="pt",
                               truncation=True, max_length=MAX_TOKENS,
                               padding=True)
            with torch.no_grad():
                summary_ids = model.generate(
                    **inputs,
                    min_length=300,
                    max_length=1500,
                    length_penalty=2.0,
                    num_beams=4,
                    early_stopping=True,
                )
            return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return final_text
    except Exception as e:
        # Return a string (not the exception object) so the Gradio text
        # output component always receives valid text.
        return str(e)
# Gradio UI: one large input textbox in, the summary text out.
demo = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(lines=20, label="Input Text"),
    outputs="text",
    title="BART Summarizer",
)

if __name__ == "__main__":
    demo.launch()