# app.py
# Gradio multi-model summarizer (two-model outputs). Uses only Hugging Face models.
# Supports URL or pasted text input, and Paragraph / Points output formats.
import time, re
from typing import List, Optional, Tuple
import torch
import trafilatura
import gradio as gr
from transformers import AutoTokenizer, pipeline
# ---------------- Helper functions ----------------
def fetch_article(url: str) -> str:
    downloaded = trafilatura.fetch_url(url)
    if not downloaded:
        raise ValueError("Failed to download URL (check link or network).")
    text = trafilatura.extract(downloaded, include_comments=False, favor_recall=True)
    if not text:
        raise ValueError("Could not extract main article text.")
    return clean_text(text)
def clean_text(text: str) -> str:
    text = re.sub(r'\s+', ' ', text).strip()
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text) if len(s) > 3]
    return " ".join(sentences)
def to_bullet_points(summary_text: str) -> str:
    sentences = re.split(r'(?<=[.!?])\s+', summary_text.strip())
    sentences = [s.strip() for s in sentences if len(s.strip()) > 3]
    return "\n".join([f"- {s}" for s in sentences])
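# Illustrative example for to_bullet_points (kept as a comment so importing the
# module stays side-effect free): to_bullet_points("Rates rose. Markets fell.")
# returns "- Rates rose.\n- Markets fell."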
# ---------------- Robust token-level chunking ----------------
def _effective_max_tokens(tokenizer, default: int = 1024) -> int:
    m = getattr(tokenizer, "model_max_length", None)
    if m is None:
        return default
    try:
        m = int(m)
    except Exception:
        return default
    if m <= 0 or m > 1_000_000:
        return default
    return min(m, default)
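# Why the upper-bound guard above: tokenizers with no configured limit report a
# huge sentinel value (transformers uses int(1e30)), which must not be mistaken
# for a real context size; such values fall through to the 1024-token default.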
def chunk_by_tokens_safe(text: str, tokenizer, max_tokens: int, overlap: int = 64, max_chunks: int = 20) -> List[str]:
    ids = tokenizer.encode(text, add_special_tokens=False, truncation=False)
    if not ids:
        return []
    safe_max = max_tokens
    step = max(safe_max - overlap, 1)
    chunks = []
    start = 0
    while start < len(ids) and len(chunks) < max_chunks:
        end = min(start + safe_max, len(ids))
        chunk_ids = ids[start:end]
        if len(chunk_ids) > safe_max:
            chunk_ids = chunk_ids[:safe_max]
        chunks.append(tokenizer.decode(chunk_ids, skip_special_tokens=True))
        if end >= len(ids):
            break
        start += step
    return chunks
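# Worked example for chunk_by_tokens_safe: with max_tokens=512 and overlap=64,
# the window advances by step = 512 - 64 = 448 tokens, so a 1500-token article
# yields chunks starting at tokens 0, 448, 896, and 1344 (four chunks), each
# sharing 64 tokens of context with its neighbour.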
# ---------------- Hugging Face summarizer wrapper ----------------
class HFSummarizer:
    def __init__(self, model_name: str, device: Optional[int] = None,
                 max_input_tokens: Optional[int] = None,
                 per_chunk_new_tokens: int = 150,
                 per_chunk_min_new_tokens: int = 50,
                 reduce_new_tokens: int = 200,
                 beams: int = 4):
        if device is None:
            device = 0 if torch.cuda.is_available() else -1
        self.device = device
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        model_max = _effective_max_tokens(self.tokenizer, default=1024)
        if max_input_tokens is None:
            # Leave headroom for special tokens the tokenizer adds.
            self.max_input_tokens = max(128, model_max - 4)
        else:
            self.max_input_tokens = min(max_input_tokens, model_max - 4)
        self.per_chunk_new_tokens = per_chunk_new_tokens
        self.per_chunk_min_new_tokens = per_chunk_min_new_tokens
        self.reduce_new_tokens = reduce_new_tokens
        self.beams = beams
        # Build the HF pipeline (the model is downloaded on first use).
        self.pipe = pipeline("summarization", model=self.model_name, tokenizer=self.tokenizer, device=self.device)
    def _prepare_for_model(self, text: str) -> str:
        # T5-family models expect a "summarize: " task prefix. The pipeline often
        # handles this, but passing it explicitly improves results for T5.
        if "t5" in self.model_name.lower():
            return "summarize: " + text
        return text
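    # Example of the prefix above: with model_name="t5-small", the text sent to
    # the model becomes "summarize: <article text>".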
    def _summarize_once(self, text: str, max_new: int, min_new: int) -> str:
        prepared = self._prepare_for_model(text)
        tok_len = len(self.tokenizer.encode(prepared, add_special_tokens=True, truncation=False))
        print(f" -> model '{self.model_name}', tokens (with special tokens): {tok_len}")
        out = self.pipe(prepared, num_beams=self.beams, do_sample=False, length_penalty=1.0,
                        early_stopping=True, max_new_tokens=max_new, min_new_tokens=min_new,
                        truncation=True)[0]["summary_text"]
        return out.strip()
    def summarize(self, text: str) -> str:
        text = clean_text(text)
        ids = self.tokenizer.encode(text, add_special_tokens=False, truncation=False)
        if len(ids) <= self.max_input_tokens:
            return self._summarize_once(text, self.per_chunk_new_tokens, self.per_chunk_min_new_tokens)
        # Map step: summarize each chunk independently.
        chunks = chunk_by_tokens_safe(text, self.tokenizer, max_tokens=self.max_input_tokens, overlap=64, max_chunks=20)
        if not chunks:
            return ""
        partials = []
        for c in chunks:
            partials.append(self._summarize_once(c, self.per_chunk_new_tokens, self.per_chunk_min_new_tokens))
        # Reduce step: summarize the stitched partial summaries in one final pass.
        stitched = " ".join(partials)
        final = self._summarize_once(stitched, self.reduce_new_tokens, min(80, self.reduce_new_tokens // 3))
        return final
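# Usage sketch (illustrative; `article_text` is a placeholder, and the model
# weights are downloaded on first instantiation):
#   summ = HFSummarizer("sshleifer/distilbart-cnn-12-6")
#   print(summ.summarize(article_text))  # single pass if it fits, else map-reduce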
# ---------------- Multi-model coordinator ----------------
class MultiHFSummarizer:
    def __init__(self, models: Optional[List[str]] = None):
        if models is None:
            models = ["sshleifer/distilbart-cnn-12-6", "facebook/bart-large-cnn"]
        self.models = models
        self._instances = {}
    def _get_inst(self, model_name: str) -> HFSummarizer:
        # Cache one HFSummarizer per model so weights load at most once per coordinator.
        if model_name not in self._instances:
            self._instances[model_name] = HFSummarizer(model_name=model_name)
        return self._instances[model_name]
    def summarize_text(self, text: str) -> List[Tuple[str, str]]:
        results = []
        for m in self.models:
            print(f"\nRunning model: {m}")
            inst = self._get_inst(m)
            t0 = time.time()
            s = inst.summarize(text)
            elapsed = round(time.time() - t0, 2)
            print(f"Model {m} finished in {elapsed}s")
            results.append((m, s))
        return results
    def summarize_url(self, url: str) -> List[Tuple[str, str]]:
        text = fetch_article(url)
        return self.summarize_text(text)
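# Usage sketch (illustrative; `article_text` is a placeholder):
#   coord = MultiHFSummarizer(["t5-small", "sshleifer/distilbart-cnn-12-6"])
#   for name, summary in coord.summarize_text(article_text):
#       print(name, "->", summary)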
# ---------------- Gradio UI logic ----------------
MODEL_OPTIONS = {
    "DistilBART + BART-large": ["sshleifer/distilbart-cnn-12-6", "facebook/bart-large-cnn"],
    "DistilBART + Pegasus": ["sshleifer/distilbart-cnn-12-6", "google/pegasus-cnn_dailymail"],
    "DistilBART + T5-small": ["sshleifer/distilbart-cnn-12-6", "t5-small"],
}
def summarize_ui(input_type: str, input_value: str, model_choice: str, out_format: str):
    # Resolve the input text (fetch from URL or use the pasted text).
    try:
        if input_type == "URL":
            text = fetch_article(input_value)
        else:
            text = input_value
        if not text or not text.strip():
            return "No text found. Please paste text or check the URL."
    except Exception as e:
        return f"Error fetching input: {e}"
    models = MODEL_OPTIONS.get(model_choice, MODEL_OPTIONS["DistilBART + BART-large"])
    # Warn if the user selected a heavy model on CPU.
    warning = ""
    if (not torch.cuda.is_available()) and any("bart-large" in m for m in models):
        warning = ("**Warning:** You're running on CPU. `facebook/bart-large-cnn` is large and may run out of memory "
                   "or be slow. Consider a lighter pair (DistilBART + T5-small) or request a GPU in the Space settings.\n\n")
    coordinator = MultiHFSummarizer(models=models)
    outputs = coordinator.summarize_text(text)
    md = warning
    for model_name, summary in outputs:
        md += f"### {model_name}\n\n"
        if out_format == "Points":
            md += to_bullet_points(summary) + "\n\n"
        else:
            md += summary + "\n\n"
    return md
# ---------------- Build Gradio interface ----------------
with gr.Blocks(title="Multi-Model Summarizer") as demo:
    gr.Markdown("# Multi-Model Summarizer (Hugging Face models)\nChoose input, model pair, and output format (paragraph or points).")
    with gr.Row():
        input_type = gr.Radio(["URL", "Text"], value="URL", label="Input type")
        model_choice = gr.Dropdown(list(MODEL_OPTIONS.keys()), value="DistilBART + BART-large", label="Model pair")
        out_format = gr.Dropdown(["Paragraph", "Points"], value="Paragraph", label="Output format")
    input_value = gr.Textbox(lines=6, placeholder="Paste article URL or text here...")
    run_btn = gr.Button("Summarize")
    output_md = gr.Markdown()
    run_btn.click(fn=summarize_ui, inputs=[input_type, input_value, model_choice, out_format], outputs=output_md)
# Launch (Spaces will serve this automatically)
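# Optionally, demo.queue() can be chained before launch() to serialize
# long-running summarization requests and reduce timeouts on shared hardware.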
if __name__ == "__main__":
    demo.launch()