Spaces:
Sleeping
Sleeping
# app.py
# Gradio multi-model summarizer (two-model outputs). Uses only Hugging Face models.
# Supports URL or pasted-text input, and Paragraph / Points output formats.
import os, time, re | |
from typing import List, Optional, Tuple | |
import torch | |
import trafilatura | |
import gradio as gr | |
from transformers import AutoTokenizer, pipeline | |
# ---------------- Helper functions ---------------- | |
def fetch_article(url: str) -> str:
    """Download a web page and return its main article text, cleaned.

    Args:
        url: Address of the page to fetch.

    Returns:
        The extracted article text after it has been passed through
        ``clean_text``.

    Raises:
        ValueError: If the download yields nothing, or if no article text
            could be extracted from the downloaded page.
    """
    page = trafilatura.fetch_url(url)
    if not page:
        raise ValueError("Failed to download URL (check link or network).")

    article = trafilatura.extract(page, include_comments=False, favor_recall=True)
    if not article:
        raise ValueError("Could not extract main article text.")

    return clean_text(article)
def clean_text(text: str) -> str:
    """Collapse whitespace and drop very short sentence fragments.

    Every run of whitespace becomes a single space, the text is split after
    sentence-ending punctuation (``.``, ``!``, ``?``), and only fragments
    longer than three characters are kept.

    Args:
        text: Raw input text.

    Returns:
        The kept fragments joined by single spaces (possibly an empty string).
    """
    collapsed = re.sub(r'\s+', ' ', text).strip()

    kept = []
    for fragment in re.split(r'(?<=[.!?])\s+', collapsed):
        # Fragments of three characters or fewer (e.g. "Hi.") are discarded.
        if len(fragment) > 3:
            kept.append(fragment)

    return " ".join(kept)
def to_bullet_points(summary_text: str) -> str:
    """Render a summary as a Markdown-style bullet list, one sentence per line.

    The text is split after ``.``, ``!`` or ``?``; each piece is stripped and
    pieces of three characters or fewer are skipped.

    Args:
        summary_text: Summary to reformat.

    Returns:
        Lines of the form ``"- <sentence>"`` joined with newlines, or an empty
        string when no piece is long enough.
    """
    bullets = []
    for piece in re.split(r'(?<=[.!?])\s+', summary_text.strip()):
        piece = piece.strip()
        if len(piece) > 3:
            bullets.append(f"- {piece}")
    return "\n".join(bullets)
# ---------------- Robust token-level chunking ---------------- | |
def _effective_max_tokens(tokenizer, default: int = 1024) -> int: | |
m = getattr(tokenizer, "model_max_length", None) | |
if m is None: | |
return default | |
try: | |
m = int(m) | |
except Exception: | |
return default | |
if m <= 0 or m > 1_000_000: | |
return default | |
return min(m, default) | |
def chunk_by_tokens_safe(text: str, tokenizer, max_tokens: int, overlap: int = 64, max_chunks: int = 20) -> List[str]:
    """Split ``text`` into overlapping chunks of at most ``max_tokens`` tokens.

    The text is encoded once without special tokens, sliced into windows of
    token ids, and each window is decoded back to a string.

    Args:
        text: Text to split.
        tokenizer: Object providing ``encode(text, add_special_tokens=...,
            truncation=...)`` and ``decode(ids, skip_special_tokens=...)``.
        max_tokens: Window size in tokens. Values below 1 are treated as 1 so
            that every chunk carries at least one token.
        overlap: Number of tokens shared by consecutive windows. Clamped to
            ``[0, window - 1]``: a negative value would otherwise make the
            step larger than the window and silently skip tokens.
        max_chunks: Maximum number of chunks to return; any remaining text is
            dropped once this many chunks have been produced.

    Returns:
        The decoded chunks in document order; ``[]`` when the text encodes to
        no tokens or ``max_chunks`` is not positive.
    """
    ids = tokenizer.encode(text, add_special_tokens=False, truncation=False)
    if not ids:
        return []

    window = max(max_tokens, 1)
    overlap = min(max(overlap, 0), window - 1)
    step = window - overlap  # always >= 1 thanks to the clamp above
    total = len(ids)

    chunks: List[str] = []
    for start in range(0, total, step):
        if len(chunks) >= max_chunks:
            break
        end = min(start + window, total)
        chunks.append(tokenizer.decode(ids[start:end], skip_special_tokens=True))
        if end == total:
            # The tail is covered; a further window would only repeat overlap.
            break
    return chunks
# ---------------- Hugging Face summarizer wrapper ---------------- | |
class HFSummarizer:
    """Summarize text with a single Hugging Face ``summarization`` pipeline.

    Inputs that fit within ``max_input_tokens`` are summarized in one call.
    Longer inputs are split into token chunks (overlap 64, at most 20 chunks),
    each chunk is summarized, and the partial summaries are summarized again
    to produce the final result.
    """

    # Upper bound on extra passes used to shrink the stitched partial
    # summaries when they are themselves longer than the model input limit.
    _MAX_REDUCE_PASSES = 3

    def __init__(self, model_name: str, device: Optional[int] = None,
                 max_input_tokens: Optional[int] = None,
                 per_chunk_new_tokens: int = 150,
                 per_chunk_min_new_tokens: int = 50,
                 reduce_new_tokens: int = 200,
                 beams: int = 4):
        """Load the tokenizer and build the pipeline for ``model_name``.

        Args:
            model_name: Model identifier passed to the tokenizer and pipeline.
            device: Pipeline device index. When ``None``, 0 is used if CUDA is
                available and -1 otherwise.
            max_input_tokens: Requested input limit. When ``None`` it is
                derived from the tokenizer limit; otherwise it is capped at
                the tokenizer limit minus 4 and never allowed below 1.
            per_chunk_new_tokens: ``max_new_tokens`` for single-pass and
                per-chunk summaries.
            per_chunk_min_new_tokens: ``min_new_tokens`` for those summaries.
            reduce_new_tokens: ``max_new_tokens`` for the final summary of the
                partial summaries.
            beams: Beam count used for generation.
        """
        if device is None:
            device = 0 if torch.cuda.is_available() else -1
        self.device = device
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        model_max = _effective_max_tokens(self.tokenizer, default=1024)
        if max_input_tokens is None:
            self.max_input_tokens = max(128, model_max - 4)
        else:
            # Lower-bounded so a zero/negative request cannot produce a limit
            # that makes every input look "too long" and chunk into nothing.
            self.max_input_tokens = max(1, min(max_input_tokens, model_max - 4))
        self.per_chunk_new_tokens = per_chunk_new_tokens
        self.per_chunk_min_new_tokens = per_chunk_min_new_tokens
        self.reduce_new_tokens = reduce_new_tokens
        self.beams = beams
        # create HF pipeline (lazy enough; will download model the first time)
        self.pipe = pipeline("summarization", model=self.model_name, tokenizer=self.tokenizer, device=self.device)

    def _prepare_for_model(self, text: str) -> str:
        """Return ``text`` with the ``"summarize: "`` prefix for T5-named models."""
        # Some models (T5 family) prefer "summarize: " prefix — pipeline often handles it,
        # but giving it explicitly improves results for T5.
        if "t5" in self.model_name.lower():
            return "summarize: " + text
        return text

    def _token_count(self, text: str) -> int:
        """Return the number of tokens in ``text`` without special tokens."""
        return len(self.tokenizer.encode(text, add_special_tokens=False, truncation=False))

    def _summarize_once(self, text: str, max_new: int, min_new: int) -> str:
        """Run one pipeline call on ``text`` and return the stripped summary.

        ``min_new`` is clamped to ``[0, max_new]`` so an inconsistent
        configuration cannot request a minimum above the maximum.
        """
        min_new = max(0, min(min_new, max_new))
        prepared = self._prepare_for_model(text)
        tok_len = len(self.tokenizer.encode(prepared, add_special_tokens=True, truncation=False))
        print(f" -> model '{self.model_name}', tokens (with special tokens): {tok_len}")
        out = self.pipe(prepared, num_beams=self.beams, do_sample=False, length_penalty=1.0,
                        early_stopping=True, max_new_tokens=max_new, min_new_tokens=min_new,
                        truncation=True)[0]["summary_text"]
        return out.strip()

    def _summarize_chunks(self, text: str) -> str:
        """Chunk ``text``, summarize each chunk, and join the partial summaries."""
        chunks = chunk_by_tokens_safe(text, self.tokenizer, max_tokens=self.max_input_tokens,
                                      overlap=64, max_chunks=20)
        partials = [
            self._summarize_once(c, self.per_chunk_new_tokens, self.per_chunk_min_new_tokens)
            for c in chunks
        ]
        return " ".join(partials)

    def summarize(self, text: str) -> str:
        """Summarize ``text``, chunking it first when it exceeds the input limit.

        Args:
            text: Text to summarize.

        Returns:
            The summary, or ``""`` when there is nothing to summarize (empty
            or whitespace-only input, or input that cleans down to nothing).
        """
        # Nothing to do for blank input; avoids running the model on "".
        if not text or not text.strip():
            return ""
        text = clean_text(text)
        if not text:
            return ""

        if self._token_count(text) <= self.max_input_tokens:
            return self._summarize_once(text, self.per_chunk_new_tokens, self.per_chunk_min_new_tokens)

        stitched = self._summarize_chunks(text)
        if not stitched.strip():
            return ""

        # The joined partial summaries may still be longer than the model can
        # read; without this step the final call (truncation=True) would drop
        # everything past the limit, i.e. the later chunks of the document.
        for _ in range(self._MAX_REDUCE_PASSES):
            if self._token_count(stitched) <= self.max_input_tokens:
                break
            condensed = self._summarize_chunks(stitched)
            if not condensed.strip():
                break
            stitched = condensed

        return self._summarize_once(stitched, self.reduce_new_tokens, min(80, self.reduce_new_tokens // 3))
# ---------------- Multi-model coordinator ---------------- | |
class MultiHFSummarizer:
    """Run one text through several summarizer models, one after another.

    An ``HFSummarizer`` is created the first time a model name is needed and
    is then reused for later calls on the same coordinator.
    """

    def __init__(self, models: List[str] = None):
        """Store the model names; a two-model default is used when omitted."""
        self.models = (
            models
            if models is not None
            else ["sshleifer/distilbart-cnn-12-6", "facebook/bart-large-cnn"]
        )
        self._instances = {}

    def _get_inst(self, model_name: str) -> HFSummarizer:
        """Return the cached summarizer for ``model_name``, creating it if absent."""
        try:
            return self._instances[model_name]
        except KeyError:
            summarizer = HFSummarizer(model_name=model_name)
            self._instances[model_name] = summarizer
            return summarizer

    def _run_model(self, model_name: str, text: str) -> Tuple[str, str]:
        """Summarize ``text`` with one model, printing progress and elapsed time."""
        print(f"\nRunning model: {model_name}")
        summarizer = self._get_inst(model_name)
        started = time.time()
        summary = summarizer.summarize(text)
        elapsed = round(time.time() - started, 2)
        print(f"Model {model_name} finished in {elapsed}s")
        return (model_name, summary)

    def summarize_text(self, text: str) -> List[Tuple[str, str]]:
        """Return ``(model_name, summary)`` pairs, in the order of ``self.models``."""
        return [self._run_model(name, text) for name in self.models]

    def summarize_url(self, url: str) -> List[Tuple[str, str]]:
        """Fetch the article at ``url`` and summarize it with every model."""
        return self.summarize_text(fetch_article(url))
# ---------------- Gradio UI logic ---------------- | |
# Maps each dropdown label to the two model names run for that choice.
# Every pair starts with the same DistilBART checkpoint and differs only in
# the second model.
MODEL_OPTIONS = {
    label: ["sshleifer/distilbart-cnn-12-6", partner]
    for label, partner in (
        ("DistilBART + BART-large", "facebook/bart-large-cnn"),
        ("DistilBART + Pegasus", "google/pegasus-cnn_dailymail"),
        ("DistilBART + T5-small", "t5-small"),
    )
}
# Summarizer instances shared by every request, keyed by model name. A fresh
# coordinator per click would otherwise start with an empty cache and rebuild
# each model pipeline on every button press. The trade-off is that models
# stay referenced (and therefore in memory) after their first use.
_SHARED_SUMMARIZERS = {}


def summarize_ui(input_type: str, input_value: str, model_choice: str, out_format: str):
    """Gradio callback: summarize a URL or pasted text with the chosen model pair.

    Args:
        input_type: ``"URL"`` to fetch an article; any other value treats
            ``input_value`` as the text itself.
        input_value: The URL or the text to summarize.
        model_choice: Key into ``MODEL_OPTIONS``; unknown keys fall back to
            ``"DistilBART + BART-large"``.
        out_format: ``"Points"`` for a bullet list; anything else yields the
            summary as a paragraph.

    Returns:
        A Markdown string: one ``###`` section per model, optionally preceded
        by a CPU warning, or a single error / "no text" message.
    """
    # A blank box (or None) is never worth fetching or summarizing.
    if not input_value or not input_value.strip():
        return "No text found. Please paste text or check the URL."

    try:
        text = fetch_article(input_value.strip()) if input_type == "URL" else input_value
    except Exception as e:  # UI boundary: report the problem instead of crashing
        return f"Error fetching input: {e}"
    if not text or not text.strip():
        return "No text found. Please paste text or check the URL."

    models = MODEL_OPTIONS.get(model_choice, MODEL_OPTIONS["DistilBART + BART-large"])

    # Warn if user selected a heavy model on CPU
    warning = ""
    if (not torch.cuda.is_available()) and any("bart-large" in m for m in models):
        warning = ("**Warning:** You're running on CPU. `facebook/bart-large-cnn` is large and may run out of memory "
                   "or be slow. Consider choosing a lighter pair (DistilBART + T5-small) or request GPU in Space settings.\n\n")

    coordinator = MultiHFSummarizer(models=models)
    # Point the coordinator's per-model cache at the shared one so models
    # loaded by earlier requests are reused.
    coordinator._instances = _SHARED_SUMMARIZERS

    try:
        outputs = coordinator.summarize_text(text)
    except Exception as e:  # UI boundary: model load / inference failures
        return f"{warning}Error during summarization: {e}"

    sections = [warning]
    for model_name, summary in outputs:
        body = to_bullet_points(summary) if out_format == "Points" else summary
        sections.append(f"### {model_name}\n\n{body}\n\n")
    return "".join(sections)
# ---------------- Build Gradio interface ---------------- | |
# Declare the UI: a heading, three selectors, the input box, the run button
# and the Markdown output area, then wire the button to summarize_ui.
# NOTE(review): the source had lost its indentation; placing the three
# selectors inside gr.Row() (and the remaining components outside it) is a
# reconstruction — confirm against the deployed layout.
with gr.Blocks(title="Multi-Model Summarizer") as demo:
    gr.Markdown("# Multi-Model Summarizer (Hugging Face models)\nChoose input, model pair, and output format (paragraph or points).")
    with gr.Row():
        input_type = gr.Radio(["URL", "Text"], value="URL", label="Input type")
        model_choice = gr.Dropdown(list(MODEL_OPTIONS.keys()), value="DistilBART + BART-large", label="Model pair")
        out_format = gr.Dropdown(["Paragraph", "Points"], value="Paragraph", label="Output format")
    input_value = gr.Textbox(lines=6, placeholder="Paste article URL or text here...")
    run_btn = gr.Button("Summarize")
    output_md = gr.Markdown()
    # The inputs list order matches summarize_ui's parameter order.
    run_btn.click(fn=summarize_ui, inputs=[input_type, input_value, model_choice, out_format], outputs=output_md)
# Launch (Spaces will serve this automatically)
if __name__ == "__main__":
    demo.launch()