# app.py
"""Gradio multi-model summarizer (two-model outputs).

Uses only Hugging Face models. Supports URL or pasted text input, and
Paragraph / Points output formats.
"""

import os
import re
import time
import traceback
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
import trafilatura
from transformers import AutoTokenizer, pipeline

# ---------------- Helper functions ----------------

# Compiled once: both patterns are used on every request.
_WHITESPACE_RE = re.compile(r"\s+")
_SENTENCE_BOUNDARY_RE = re.compile(r"(?<=[.!?])\s+")


def fetch_article(url: str) -> str:
    """Download *url* and return its cleaned main article text.

    Args:
        url: Address of the page to fetch. Surrounding whitespace is ignored.

    Returns:
        The extracted article text after ``clean_text`` normalisation.

    Raises:
        ValueError: If no URL was given, the download failed, or no article
            text could be extracted.
    """
    url = (url or "").strip()
    if not url:
        raise ValueError("No URL provided.")
    # NOTE(review): the URL comes straight from the UI and is fetched
    # server-side; consider restricting schemes/hosts if this is exposed.
    downloaded = trafilatura.fetch_url(url)
    if not downloaded:
        raise ValueError("Failed to download URL (check link or network).")
    text = trafilatura.extract(downloaded, include_comments=False, favor_recall=True)
    if not text:
        raise ValueError("Could not extract main article text.")
    return clean_text(text)


def clean_text(text: str) -> str:
    """Collapse whitespace and drop sentence fragments of 3 characters or fewer."""
    text = _WHITESPACE_RE.sub(" ", text or "").strip()
    sentences = [s for s in _SENTENCE_BOUNDARY_RE.split(text) if len(s) > 3]
    return " ".join(sentences)


def to_bullet_points(summary_text: str) -> str:
    """Render *summary_text* as one ``- `` bullet per sentence.

    Sentences are split after ``.``, ``!`` or ``?``; fragments of 3 characters
    or fewer (after stripping) are omitted.
    """
    pieces = (s.strip() for s in _SENTENCE_BOUNDARY_RE.split((summary_text or "").strip()))
    return "\n".join(f"- {s}" for s in pieces if len(s) > 3)


# ---------------- Robust token-level chunking ----------------

def _effective_max_tokens(tokenizer, default: int = 1024) -> int:
    """Return a usable input-length limit for *tokenizer*, capped at *default*.

    Falls back to *default* when ``model_max_length`` is missing, not
    convertible to ``int``, non-positive, or implausibly large (> 1,000,000).
    """
    raw = getattr(tokenizer, "model_max_length", None)
    if raw is None:
        return default
    try:
        limit = int(raw)
    except (TypeError, ValueError, OverflowError):
        return default
    if limit <= 0 or limit > 1_000_000:
        return default
    return min(limit, default)


def chunk_by_tokens_safe(text: str, tokenizer, max_tokens: int, overlap: int = 64,
                         max_chunks: int = 20) -> List[str]:
    """Split *text* into decoded windows of at most *max_tokens* tokens.

    Consecutive windows share *overlap* tokens. At most *max_chunks* windows
    are produced; any text beyond that is dropped.

    Args:
        text: Text to split.
        tokenizer: Object providing ``encode`` and ``decode``.
        max_tokens: Window size in tokens; must be positive.
        overlap: Tokens shared by neighbouring windows. Negative values are
            treated as 0 so that no tokens are skipped between windows.
        max_chunks: Upper bound on the number of windows returned.

    Returns:
        The decoded windows in order, or ``[]`` when *text* encodes to nothing.

    Raises:
        ValueError: If *max_tokens* is not positive.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer.")
    ids = tokenizer.encode(text, add_special_tokens=False, truncation=False)
    if not ids:
        return []
    # Always advance by at least one token so the loop terminates even when
    # overlap >= max_tokens.
    step = max(max_tokens - max(overlap, 0), 1)
    total = len(ids)
    chunks: List[str] = []
    start = 0
    while start < total and len(chunks) < max_chunks:
        end = min(start + max_tokens, total)
        chunks.append(tokenizer.decode(ids[start:end], skip_special_tokens=True))
        if end >= total:
            break
        start += step
    return chunks


# ---------------- Hugging Face summarizer wrapper ----------------

class HFSummarizer:
    """Summarize text of any length with one Hugging Face summarization model.

    Text that fits the input budget is summarized in a single call. Longer
    text is split into overlapping token windows, each window is summarized,
    and the partial summaries are summarized again into the final result.
    """

    # Upper bound on extra condensation rounds before the final reduce call.
    _MAX_REDUCE_PASSES = 3

    def __init__(self, model_name: str, device: Optional[int] = None,
                 max_input_tokens: Optional[int] = None,
                 per_chunk_new_tokens: int = 150, per_chunk_min_new_tokens: int = 50,
                 reduce_new_tokens: int = 200, beams: int = 4):
        """Load the tokenizer and summarization pipeline for *model_name*.

        Args:
            model_name: Hugging Face model identifier.
            device: Pipeline device index; defaults to 0 when CUDA is
                available, otherwise -1.
            max_input_tokens: Input budget per call. Defaults to the
                tokenizer limit minus 4 (but at least 128); explicit values
                are capped at the tokenizer limit minus 4.
            per_chunk_new_tokens: ``max_new_tokens`` for single/per-chunk calls.
            per_chunk_min_new_tokens: ``min_new_tokens`` for single/per-chunk calls.
            reduce_new_tokens: ``max_new_tokens`` for the final reduce call.
            beams: Beam count passed to generation.
        """
        if device is None:
            device = 0 if torch.cuda.is_available() else -1
        self.device = device
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

        model_max = _effective_max_tokens(self.tokenizer, default=1024)
        # Leave 4 tokens of headroom (presumably for special tokens / prefix).
        budget = model_max - 4
        if max_input_tokens is None:
            self.max_input_tokens = max(128, budget)
        else:
            # Never allow a non-positive budget; chunking requires >= 1.
            self.max_input_tokens = max(1, min(max_input_tokens, budget))

        self.per_chunk_new_tokens = per_chunk_new_tokens
        self.per_chunk_min_new_tokens = per_chunk_min_new_tokens
        self.reduce_new_tokens = reduce_new_tokens
        self.beams = beams
        # create HF pipeline (lazy enough; will download model the first time)
        self.pipe = pipeline("summarization", model=self.model_name,
                             tokenizer=self.tokenizer, device=self.device)

    def _count_tokens(self, text: str, add_special_tokens: bool = False) -> int:
        """Return the number of tokens *text* encodes to (no truncation)."""
        return len(self.tokenizer.encode(text, add_special_tokens=add_special_tokens,
                                         truncation=False))

    def _prepare_for_model(self, text: str) -> str:
        """Prepend ``"summarize: "`` for models whose name contains ``t5``."""
        # Some models (T5 family) prefer "summarize: " prefix — pipeline often handles it,
        # but giving it explicitly improves results for T5.
        # NOTE(review): the pipeline may already add a model-configured prefix;
        # confirm this does not double the prefix for T5 checkpoints.
        if "t5" in self.model_name.lower():
            return "summarize: " + text
        return text

    def _summarize_once(self, text: str, max_new: int, min_new: int) -> str:
        """Run one pipeline call on *text* and return the stripped summary.

        *min_new* is clamped so it never exceeds *max_new* or the input's own
        token count: a short input must not be forced into a longer "summary".
        Inputs longer than the model limit are truncated by the pipeline.
        """
        prepared = self._prepare_for_model(text)
        tok_len = self._count_tokens(prepared, add_special_tokens=True)
        print(f" -> model '{self.model_name}', tokens (with special tokens): {tok_len}")
        min_new = max(0, min(min_new, max_new, tok_len))
        result = self.pipe(
            prepared,
            num_beams=self.beams,
            do_sample=False,
            length_penalty=1.0,
            early_stopping=True,
            max_new_tokens=max_new,
            min_new_tokens=min_new,
            truncation=True,
        )
        return result[0]["summary_text"].strip()

    def _summarize_chunks(self, text: str) -> str:
        """Summarize each token window of *text* and join the partial summaries."""
        chunks = chunk_by_tokens_safe(text, self.tokenizer,
                                      max_tokens=self.max_input_tokens,
                                      overlap=64, max_chunks=20)
        partials = [
            self._summarize_once(chunk, self.per_chunk_new_tokens,
                                 self.per_chunk_min_new_tokens)
            for chunk in chunks
            if chunk.strip()
        ]
        return " ".join(partials)

    def summarize(self, text: str) -> str:
        """Return a summary of *text*, chunking when it exceeds the input budget.

        Returns ``""`` when nothing remains after cleaning. For long text the
        joined partial summaries are condensed again (at most
        ``_MAX_REDUCE_PASSES`` times) until they fit the budget, so the final
        call does not silently truncate the tail of the document.
        """
        text = clean_text(text)
        if not text:
            return ""
        if self._count_tokens(text) <= self.max_input_tokens:
            return self._summarize_once(text, self.per_chunk_new_tokens,
                                        self.per_chunk_min_new_tokens)

        stitched = self._summarize_chunks(text)
        if not stitched:
            return ""

        for _ in range(self._MAX_REDUCE_PASSES):
            current = self._count_tokens(stitched)
            if current <= self.max_input_tokens:
                break
            condensed = self._summarize_chunks(stitched)
            # Stop if another pass makes no progress; the pipeline's own
            # truncation then bounds the final call.
            if not condensed or self._count_tokens(condensed) >= current:
                break
            stitched = condensed

        return self._summarize_once(stitched, self.reduce_new_tokens,
                                    min(80, self.reduce_new_tokens // 3))


# ---------------- Multi-model coordinator ----------------

# Loaded summarizers shared by every coordinator, keyed by model name, so a
# model is loaded once per process instead of once per request. Entries are
# never evicted, so every model that has been used stays in memory.
_SUMMARIZER_CACHE: Dict[str, HFSummarizer] = {}


class MultiHFSummarizer:
    """Run the same text through several models and collect their summaries."""

    def __init__(self, models: Optional[List[str]] = None):
        """Store the model names to run; defaults to DistilBART + BART-large."""
        if models is None:
            models = ["sshleifer/distilbart-cnn-12-6", "facebook/bart-large-cnn"]
        # Copy so later mutation of the caller's list cannot affect this instance.
        self.models = list(models)
        self._instances = _SUMMARIZER_CACHE

    def _get_inst(self, model_name: str) -> HFSummarizer:
        """Return the cached summarizer for *model_name*, loading it on first use."""
        if model_name not in self._instances:
            self._instances[model_name] = HFSummarizer(model_name=model_name)
        return self._instances[model_name]

    def summarize_text(self, text: str) -> List[Tuple[str, str]]:
        """Summarize *text* with each model; return ``(model_name, summary)`` pairs."""
        results = []
        for m in self.models:
            print(f"\nRunning model: {m}")
            inst = self._get_inst(m)
            t0 = time.perf_counter()
            s = inst.summarize(text)
            elapsed = round(time.perf_counter() - t0, 2)
            print(f"Model {m} finished in {elapsed}s")
            results.append((m, s))
        return results

    def summarize_url(self, url: str) -> List[Tuple[str, str]]:
        """Fetch the article at *url* and summarize it with each model."""
        return self.summarize_text(fetch_article(url))


# ---------------- Gradio UI logic ----------------

MODEL_OPTIONS = {
    "DistilBART + BART-large": ["sshleifer/distilbart-cnn-12-6", "facebook/bart-large-cnn"],
    "DistilBART + Pegasus": ["sshleifer/distilbart-cnn-12-6", "google/pegasus-cnn_dailymail"],
    "DistilBART + T5-small": ["sshleifer/distilbart-cnn-12-6", "t5-small"],
}


def summarize_ui(input_type: str, input_value: str, model_choice: str, out_format: str):
    """Gradio click handler: summarize the input and return Markdown.

    Args:
        input_type: ``"URL"`` to fetch an article; anything else treats
            *input_value* as the text itself.
        input_value: The URL or the pasted text.
        model_choice: Key of ``MODEL_OPTIONS``; unknown keys fall back to
            ``"DistilBART + BART-large"``.
        out_format: ``"Points"`` for bullet points; anything else for paragraphs.

    Returns:
        Markdown with one section per model, or a plain error message.
    """
    # get text
    try:
        text = fetch_article(input_value) if input_type == "URL" else input_value
    except Exception as e:  # UI boundary: report instead of crashing the handler
        return f"Error fetching input: {e}"
    if not text or not text.strip():
        return "No text found. Please paste text or check the URL."

    models = MODEL_OPTIONS.get(model_choice, MODEL_OPTIONS["DistilBART + BART-large"])

    # Warn if user selected a heavy model on CPU
    warning = ""
    if (not torch.cuda.is_available()) and any("bart-large" in m for m in models):
        warning = ("**Warning:** You're running on CPU. `facebook/bart-large-cnn` is large and may run out of memory "
                   "or be slow. Consider choosing a lighter pair (DistilBART + T5-small) or request GPU in Space settings.\n\n")

    try:
        outputs = MultiHFSummarizer(models=models).summarize_text(text)
    except Exception as e:  # UI boundary: keep the traceback in the server log
        traceback.print_exc()
        return f"{warning}Error during summarization: {e}"

    sections = [warning]
    for model_name, summary in outputs:
        body = to_bullet_points(summary) if out_format == "Points" else summary
        sections.append(f"### {model_name}\n\n{body}\n\n")
    return "".join(sections)


# ---------------- Build Gradio interface ----------------

with gr.Blocks(title="Multi-Model Summarizer") as demo:
    gr.Markdown("# Multi-Model Summarizer (Hugging Face models)\nChoose input, model pair, and output format (paragraph or points).")
    with gr.Row():
        input_type = gr.Radio(["URL", "Text"], value="URL", label="Input type")
        model_choice = gr.Dropdown(list(MODEL_OPTIONS.keys()), value="DistilBART + BART-large", label="Model pair")
        out_format = gr.Dropdown(["Paragraph", "Points"], value="Paragraph", label="Output format")
    input_value = gr.Textbox(lines=6, placeholder="Paste article URL or text here...")
    run_btn = gr.Button("Summarize")
    output_md = gr.Markdown()
    run_btn.click(fn=summarize_ui, inputs=[input_type, input_value, model_choice, out_format], outputs=output_md)

# Launch (Spaces will serve this automatically)
if __name__ == "__main__":
    demo.launch()