# bart-summarizer — app.py
# Author: xTorch8 (commit e417e31, "Import torch")
import gradio as gr
import os
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
# Hugging Face Hub repo id of the fine-tuned BART summarization checkpoint.
MODEL = "xTorch8/fine-tuned-bart"
# Hub access token read from the environment (None if unset); presumably the
# model repo is private and requires it — TODO confirm.
TOKEN = os.getenv("TOKEN")
# Token budget used both as the tokenizer truncation length and as the
# character-count threshold/chunk-size basis in summarize_text below.
MAX_TOKENS = 1024
# Load model and tokenizer once at import time so every request reuses them.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, token = TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL, token = TOKEN)
def summarize_text(text):
    """Summarize *text* with the fine-tuned BART model.

    Long inputs are split into overlapping character chunks, each chunk is
    summarized independently, and — when the joined chunk summaries are still
    long — a second summarization pass compresses them into one final summary.

    Parameters
    ----------
    text : str
        Raw input text to summarize.

    Returns
    -------
    str
        The generated summary, or the error message if generation fails.
    """
    try:
        # Chunk on characters, sized as MAX_TOKENS * 4 on the rough heuristic
        # of ~4 characters per token, so each chunk roughly fits the window.
        chunk_size = MAX_TOKENS * 4
        overlap = chunk_size // 4  # 25% overlap preserves context across cuts
        step = chunk_size - overlap
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]

        summaries = []
        for chunk in chunks:
            inputs = tokenizer(chunk, return_tensors="pt", truncation=True,
                               max_length=1024, padding=True)
            with torch.no_grad():  # inference only — no gradients needed
                summary_ids = model.generate(
                    **inputs,
                    max_length=1500,
                    length_penalty=2.0,
                    num_beams=4,
                    early_stopping=True,
                )
            summaries.append(
                tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            )

        final_text = " ".join(summaries)

        # If the concatenated chunk summaries are still long (character count
        # used as a proxy for token length), run one more compression pass.
        if len(final_text) > MAX_TOKENS:
            inputs = tokenizer(final_text, return_tensors="pt", truncation=True,
                               max_length=1024, padding=True)
            with torch.no_grad():
                summary_ids = model.generate(
                    **inputs,
                    min_length=300,
                    max_length=1500,
                    length_penalty=2.0,
                    num_beams=4,
                    early_stopping=True,
                )
            return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return final_text
    except Exception as e:
        # Bug fix: return the message, not the Exception object — the Gradio
        # "text" output component expects a string.
        return str(e)
# Gradio UI: one multi-line textbox in, the plain-text summary out.
input_box = gr.Textbox(lines=20, label="Input Text")
demo = gr.Interface(
    fn=summarize_text,
    inputs=input_box,
    outputs="text",
    title="BART Summarizer",
)

# Start the web server only when run directly, not when imported.
if __name__ == "__main__":
    demo.launch()