Spaces: Runtime error
Update main.py
main.py (CHANGED)
Old side of the diff (removed lines are prefixed with "-"; lines whose content was lost in extraction are shown as bare "-"):

@@ -1,4 +1,4 @@
-# main.py
 import os
 os.environ['HF_HOME'] = '/tmp'

@@ -22,27 +22,26 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("summarizer")

 # -------------------------
-# Device selection (CPU
 # -------------------------
-# The user asked for CPU — force device -1. If you later enable GPU, set DEVICE accordingly.
 USE_GPU = False
 DEVICE = -1
-logger.info("Startup: forcing CPU usage for

 # -------------------------
 # Model names and caches
 # -------------------------
 PEGASUS_MODEL = "google/pegasus-large"
 LED_MODEL = "allenai/led-large-16384"
-
-PARAM_MODEL = "google/flan-t5-small"

 _SUMMARIZER_CACHE: Dict[str, Any] = {}
 _PARAM_GENERATOR = None
 _PREFERRED_SUMMARIZER_KEY: Optional[str] = None

 # -------------------------
-# Utilities
 # -------------------------
 _STOPWORDS = {
     "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"

@@ -68,7 +67,6 @@ def extractive_prefilter(text: str, top_k: int = 6) -> str:
     return " ".join(chosen)

 def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
-    # small chunks to keep CPU generation per-call bounded
     n = len(text)
     if n <= max_chars:
         return [text]

@@ -89,97 +87,98 @@ def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) ->
         start = end
     return [p for p in parts if p]

-def _first_int_from_text(s: str, fallback: Optional[int] = None) -> Optional[int]:
-    m = re.search(r"\d{1,4}", s)
-    return int(m.group()) if m else fallback
-
 # -------------------------
-#
 # -------------------------
-def
-
-
-
-
-
-
-        _PARAM_GENERATOR = pipeline("text2text-generation", model=p_mod, tokenizer=p_tok, device=DEVICE)
-        logger.info("Param-generator loaded as text2text-generation.")
-    except Exception as e:
-        logger.exception("Param-generator failed to load as text2text: %s", e)
-        _PARAM_GENERATOR = None
-
-def preload_models_at_startup():
-    global _PREFERRED_SUMMARIZER_KEY
-    # Preload Pegasus and LED (best-effort). If fail, load fallback.
-    logger.info("Preloading summarizer models (Pegasus, LED, fallback)...")
-    p = safe_load_pipeline(PEGASUS_MODEL)
-    if p:
-        _SUMMARIZER_CACHE["pegasus"] = p
-        _PREFERRED_SUMMARIZER_KEY = "pegasus"
-    else:
-        logger.warning("Pegasus failed to load on CPU; will still attempt LED and fallback.")
-
-    l = safe_load_pipeline(LED_MODEL)
-    if l:
-        _SUMMARIZER_CACHE["led"] = l
-
-    # Always ensure fallback available
-    f = safe_load_pipeline(FALLBACK_MODEL)
-    if f:
-        _SUMMARIZER_CACHE["distilbart"] = f
-        if not _PREFERRED_SUMMARIZER_KEY:
-            _PREFERRED_SUMMARIZER_KEY = "distilbart"
-    else:
-        logger.critical("Fallback model failed to load. The app will not be able to summarize.")
-
-    # Load param generator if possible (small model)
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-

 # -------------------------
-#
 # -------------------------
-def
     """
-
-
     """
     model_name = getattr(pipe.model.config, "name_or_path", "") or ""
-
-    is_led = "led" in model_name_lower or "longformer" in model_name_lower

-    #
     fast_cfg = {
         "max_new_tokens": 64 if short_target else (120 if not is_led else 240),
         "do_sample": True,
-        "top_p": 0.
         "temperature": 0.85,
         "num_beams": 1,
         "early_stopping": True,
         "no_repeat_ngram_size": 3,
     }
-
     try:
-
-        return out
     except Exception as e:
-        logger.warning("Fast pass failed: %s

-    # QUALITY beam pass (more expensive)
     quality_cfg = {
         "max_new_tokens": 140 if not is_led else 320,
         "do_sample": False,

@@ -187,121 +186,99 @@ def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) ->
         "early_stopping": True,
         "no_repeat_ngram_size": 3,
     }
-
     try:
-
-
-
-        logger.exception("Quality pass failed: %s", e2)

-    # extractive
     try:
         return extractive_prefilter(text_prompt, top_k=3)
     except Exception:
-        return "Summarization failed;

 # -------------------------
-# Param generator (AI decision)
 # -------------------------
 def generate_summarization_config(text: str) -> Dict[str, Any]:
-    """
-    Uses the text2text param-generator to output a JSON config.
-    If the generator fails or returns something noisy (e.g., echoes the input),
-    fall back to a safe heuristic.
-    """
     defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
-
-
         words = len(text.split())
         length = "short" if words < 150 else ("medium" if words < 800 else "long")
         mn, mx = defaults[length]
         return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

-    pg = _PARAM_GENERATOR
-    if pg is None:
-        logger.info("Param-generator not available; using fallback heuristic.")
-        return fallback()
-
     prompt = (
-        "Recommend summarization settings for this text. Answer ONLY with JSON
         '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
         "Text:\n'''"
         + text[:3000] + "'''"
     )
-
     try:
-        out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1
-
-        out = out_item.get("generated_text") or out_item.get("summary_text") or out_item.get("text") or ""
         out = (out or "").strip()
-
-        # COMMON FAILURE MODE: the model just echoes the input — reject that
-        # If output contains a long substring of the input, treat as invalid.
         if not out:
             raise ValueError("Empty param-generator output")
-        #
         input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
         out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
-        if len(input_words)
-            logger.warning("Param-generator appears to echo input;
-

-        # Find JSON object in output
         jmatch = re.search(r"\{.*\}", out, re.DOTALL)
         if jmatch:
             raw = jmatch.group().replace("'", '"')
             cfg = json.loads(raw)
         else:
-            # attempt to parse line with key:value pairs (tolerant)
             cfg = None

         if not cfg or not isinstance(cfg, dict):
-
-
-
-
-
-        min_w = cfg.get("min_words") or cfg.get("min_length") or cfg.get("min")
-        max_w = cfg.get("max_words") or cfg.get("max_length") or cfg.get("max")
-
-        if length not in ("short","medium","long"):
-            words = len(text.split())
-            length = "short" if words < 150 else ("medium" if words < 800 else "long")
-        if tone not in ("neutral","formal","casual","bullet"):
-            tone = "neutral"
-
-        defaults_min, defaults_max = defaults.get(length, (50,130))
-        try:
-            mn = int(min_w) if min_w is not None else defaults_min
-            mx = int(max_w) if max_w is not None else defaults_max
-        except Exception:
-            mn, mx = defaults_min, defaults_max
-
         mn = max(5, min(mn, 2000))
         mx = max(mn + 5, min(mx, 4000))
-        logger.info("Param-generator suggested length=%s tone=%s min=%s max=%s", length, tone, mn, mx)
         return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
-
     except Exception as e:
-        logger.exception("Param-generator failed
-

 # -------------------------
-#
 # -------------------------
-
-
-
-REFINE_TIMEOUT_SECONDS = 60 # final refinement timeout
-MAX_TOTAL_SECONDS = 180 # overall safety cap for a request
-
-executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

 def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
-    """
-    Submit chunk summarization tasks to threadpool; apply per-chunk timeout.
-    If a chunk times out or fails, use extractive_prefilter fallback.
-    """
     futures = {}
     results = [None] * len(chunks)
     for idx, chunk in enumerate(chunks):

@@ -310,26 +287,24 @@ def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> Lis
         futures[fut] = idx

     start = time.time()
-    for fut in as_completed(futures
         idx = futures[fut]
         try:
             remaining = max(0.1, CHUNK_TIMEOUT_SECONDS - (time.time() - start))
-            # bound waiting for each future to CHUNK_TIMEOUT_SECONDS
             results[idx] = fut.result(timeout=remaining)
         except TimeoutError:
             logger.warning("Chunk %d timed out; using extractive fallback.", idx)
             results[idx] = extractive_prefilter(chunks[idx], top_k=3)
         except Exception as e:
-            logger.exception("Chunk %d failed: %s;
             results[idx] = extractive_prefilter(chunks[idx], top_k=3)
-    # ensure no None
     for i, r in enumerate(results):
         if not r:
             results[i] = extractive_prefilter(chunks[i], top_k=3)
     return results

 # -------------------------
-# Prompt
 # -------------------------
 def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
     tone = (tone or "neutral").lower()

@@ -337,13 +312,13 @@ def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int]
         instr = "Produce concise bullet points. Each bullet <= 20 words. No extra commentary."
     elif tone == "short":
         ts = target_sentences or 1
-        instr = f"Summarize in {ts} sentence{'s' if ts>1 else ''}. Be abstractive
     elif tone == "formal":
-        instr = "Summarize in a formal professional tone
     elif tone == "casual":
-        instr = "Summarize in a casual tone
     elif tone == "long":
-        instr = "Provide
     else:
         instr = "Summarize in 2-3 clear sentences."
     instr += " Do not repeat information. Prefer rephrasing."

@@ -354,30 +329,51 @@ def refine_combined(pipe, summaries_list: List[str], tone: str, final_target_sen
     if len(combined.split()) > 1200:
         combined = extractive_prefilter(combined, top_k=20)
     prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
-    # run refine in executor to apply timeout
     fut = executor.submit(summarize_with_model, pipe, prompt, short_target=False)
     try:
         return fut.result(timeout=REFINE_TIMEOUT_SECONDS)
     except TimeoutError:
-        logger.warning("Refine
         return " ".join(summaries_list[:6])
     except Exception as e:
-        logger.exception("Refine
         return " ".join(summaries_list[:6])

 # -------------------------
-#
 # -------------------------
 @app.route("/", methods=["GET"])
 def home():
     try:
         return render_template("index.html")
     except Exception:
-        return "Summarizer (

 @app.route("/summarize", methods=["POST"])
 def summarize_route():
-
     data = request.get_json(force=True) or {}
     text = (data.get("text") or "").strip()[:90000]
     user_model_pref = (data.get("model") or "auto").lower()

@@ -387,16 +383,16 @@ def summarize_route():
     if not text or len(text.split()) < 5:
         return jsonify({"error": "Input too short."}), 400

-    #
-    if requested_length in ("auto",
         cfg = generate_summarization_config(text)
-        length_choice = cfg.get("length",
-        tone_choice = cfg.get("tone",
     else:
         length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
         tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"

-    #
     words_len = len(text.split())
     prefer_led = False
     if user_model_pref == "led":

@@ -406,84 +402,65 @@ def summarize_route():
     else:
         if length_choice == "long" or words_len > 3000:
             prefer_led = True
-        else:
-            prefer_led = False

-    # Choose actual pipe
     model_key = "led" if prefer_led else (_PREFERRED_SUMMARIZER_KEY or "distilbart")
-
-
-
-
-
-            _SUMMARIZER_CACHE[model_key] = pipe
-        except Exception as e:
-            logger.exception("On-demand load failed: %s", e)
-
-    # Ensure we have at least a fallback
-    if model_key not in _SUMMARIZER_CACHE:
         model_key = "distilbart"
-    summarizer_pipe = _SUMMARIZER_CACHE[model_key]

-    #
     if model_key != "led" and words_len > 2500:
         text_for_chunks = extractive_prefilter(text, top_k=40)
     else:
         text_for_chunks = text

-    #
     if model_key == "led":
         chunk_max = 6000
         overlap = 400
     else:
-        chunk_max = 800
         overlap = 120

     chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max, overlap=overlap)

-    #
-    chunk_target_sentences = 1 if length_choice == "short" else 2
-    t_chunks_start = time.time()
     try:
-        chunk_summaries = summarize_chunks_parallel(summarizer_pipe, chunks,
     except Exception as e:
         logger.exception("Chunk summarization orchestration failed: %s", e)
-        # fallback: simple extractive split
         chunk_summaries = [extractive_prefilter(c, top_k=3) for c in chunks]
-    t_chunks_end = time.time()

-    #
     refine_pipe = _SUMMARIZER_CACHE.get("pegasus") or summarizer_pipe
-
-    final = refine_combined(refine_pipe, chunk_summaries, tone_choice,

-    #
     if tone_choice == "bullet":
         parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
         bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
         final = "\n".join(bullets[:20])

-    elapsed = time.time() -
     meta = {
         "length_choice": length_choice,
         "tone": tone_choice,
-        "model_requested": user_model_pref,
         "model_used": model_key,
         "chunks": len(chunks),
         "input_words": words_len,
         "time_seconds": round(elapsed, 2),
-        "
-        "device": "cpu",
-        "workers_threads": MAX_WORKERS,
-        "per_chunk_timeout": CHUNK_TIMEOUT_SECONDS,
-        "refine_timeout": REFINE_TIMEOUT_SECONDS
     }
-
     return jsonify({"summary": final, "meta": meta})

 # -------------------------
-#
 # -------------------------
 if __name__ == "__main__":
-    # For
     app.run(host="0.0.0.0", port=7860, debug=False)
New side of the diff (added lines are prefixed with "+"):

@@ -1,4 +1,4 @@
+# main.py (REPLACE your existing file with this)
 import os
 os.environ['HF_HOME'] = '/tmp'

@@ -22,27 +22,26 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("summarizer")

 # -------------------------
+# Device selection (CPU by default)
 # -------------------------
 USE_GPU = False
 DEVICE = -1
+logger.info("Startup: forcing CPU usage for models (DEVICE=%s)", DEVICE)

 # -------------------------
 # Model names and caches
 # -------------------------
 PEGASUS_MODEL = "google/pegasus-large"
 LED_MODEL = "allenai/led-large-16384"
+DISTILBART_MODEL = "sshleifer/distilbart-cnn-12-6"
+PARAM_MODEL = "google/flan-t5-small"

 _SUMMARIZER_CACHE: Dict[str, Any] = {}
 _PARAM_GENERATOR = None
 _PREFERRED_SUMMARIZER_KEY: Optional[str] = None

 # -------------------------
+# Utilities: chunking, extractive fallback
 # -------------------------
 _STOPWORDS = {
     "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"

@@ -68,7 +67,6 @@ def extractive_prefilter(text: str, top_k: int = 6) -> str:
     return " ".join(chosen)

 def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
     n = len(text)
     if n <= max_chars:
         return [text]

@@ -89,97 +87,98 @@ def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) ->
         start = end
     return [p for p in parts if p]

 # -------------------------
+# safe loader (defined before any calls)
 # -------------------------
+def safe_load_pipeline(model_name: str):
+    """
+    Try to load a summarization pipeline robustly:
+      - try fast tokenizer first
+      - if that fails, try use_fast=False
+      - return pipeline or None if both fail
+    """
     try:
+        logger.info("Loading tokenizer/model for %s (fast)...", model_name)
+        tok = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
+        logger.info("Loaded %s (fast tokenizer)", model_name)
+        return pipe
+    except Exception as e_fast:
+        logger.warning("Fast tokenizer failed for %s: %s. Trying slow tokenizer...", model_name, e_fast)
+        try:
+            tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+            pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
+            logger.info("Loaded %s (slow tokenizer)", model_name)
+            return pipe
+        except Exception as e_slow:
+            logger.exception("Slow tokenizer failed for %s: %s", model_name, e_slow)
+            return None

 # -------------------------
+# get_summarizer: lazy load + cache + fallback
 # -------------------------
+def get_summarizer(key: str):
     """
+    key: 'pegasus'|'led'|'distilbart'|'auto'
+    returns a pipeline (cached), or raises RuntimeError if no pipeline can be loaded.
     """
+    key = (key or "auto").lower()
+    if key == "auto":
+        key = _PREFERRED_SUMMARIZER_KEY or "distilbart"
+
+    # direct mapping
+    model_name = {
+        "pegasus": PEGASUS_MODEL,
+        "led": LED_MODEL,
+        "distilbart": DISTILBART_MODEL
+    }.get(key, DISTILBART_MODEL)
+
+    if key in _SUMMARIZER_CACHE:
+        return _SUMMARIZER_CACHE[key]
+
+    # try to load
+    logger.info("Attempting to lazy-load summarizer '%s' -> %s", key, model_name)
+    pipe = safe_load_pipeline(model_name)
+    if pipe:
+        _SUMMARIZER_CACHE[key] = pipe
+        return pipe
+
+    # fallback attempts
+    logger.warning("Failed to load %s. Trying distilbart fallback.", key)
+    if "distilbart" in _SUMMARIZER_CACHE:
+        return _SUMMARIZER_CACHE["distilbart"]
+    fb = safe_load_pipeline(DISTILBART_MODEL)
+    if fb:
+        _SUMMARIZER_CACHE["distilbart"] = fb
+        return fb
+
+    # nothing works
+    raise RuntimeError("No summarizer available. Install required libraries and/or choose smaller model.")
+
+# -------------------------
+# Generation strategy + small helpers
+# -------------------------
+def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) -> str:
     model_name = getattr(pipe.model.config, "name_or_path", "") or ""
+    is_led = "led" in model_name.lower() or "longformer" in model_name.lower()

+    # Fast sampling pass
     fast_cfg = {
         "max_new_tokens": 64 if short_target else (120 if not is_led else 240),
         "do_sample": True,
+        "top_p": 0.92,
         "temperature": 0.85,
         "num_beams": 1,
         "early_stopping": True,
         "no_repeat_ngram_size": 3,
     }
     try:
+        return pipe(text_prompt, **fast_cfg)[0].get("summary_text","").strip()
     except Exception as e:
+        logger.warning("Fast pass failed: %s, trying quality pass...", e)

     quality_cfg = {
         "max_new_tokens": 140 if not is_led else 320,
         "do_sample": False,

@@ -187,121 +186,99 @@ def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) ->
         "early_stopping": True,
         "no_repeat_ngram_size": 3,
     }
     try:
+        return pipe(text_prompt, **quality_cfg)[0].get("summary_text","").strip()
+    except Exception as e:
+        logger.exception("Quality pass failed: %s", e)

+    # fallback extractive
     try:
         return extractive_prefilter(text_prompt, top_k=3)
     except Exception:
+        return "Summarization failed; try shorter input."

 # -------------------------
+# Param generator (AI decision) - lazy loader
 # -------------------------
+def get_param_generator():
+    global _PARAM_GENERATOR
+    if _PARAM_GENERATOR is not None:
+        return _PARAM_GENERATOR
+    # try to load text2text pipeline for PARAM_MODEL
+    try:
+        logger.info("Loading param-generator (text2text) lazily: %s", PARAM_MODEL)
+        tok = AutoTokenizer.from_pretrained(PARAM_MODEL)
+        mod = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
+        _PARAM_GENERATOR = pipeline("text2text-generation", model=mod, tokenizer=tok, device=DEVICE)
+        logger.info("Param-generator loaded.")
+        return _PARAM_GENERATOR
+    except Exception as e:
+        logger.exception("Param-generator lazy load failed: %s", e)
+        _PARAM_GENERATOR = None
+        return None
+
 def generate_summarization_config(text: str) -> Dict[str, Any]:
     defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
+    pg = get_param_generator()
+    if pg is None:
         words = len(text.split())
         length = "short" if words < 150 else ("medium" if words < 800 else "long")
         mn, mx = defaults[length]
         return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

     prompt = (
+        "Recommend summarization settings for this text. Answer ONLY with JSON like:\n"
         '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
         "Text:\n'''"
         + text[:3000] + "'''"
     )
     try:
+        out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1)[0]
+        out = out_item.get("generated_text") or out_item.get("summary_text") or ""
         out = (out or "").strip()
         if not out:
             raise ValueError("Empty param-generator output")
+        # reject noisy echo outputs
         input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
         out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
+        if len(input_words) and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
+            logger.warning("Param-generator appears to echo input; using heuristic fallback.")
+            words = len(text.split())
+            length = "short" if words < 150 else ("medium" if words < 800 else "long")
+            mn, mx = defaults[length]
+            return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

         jmatch = re.search(r"\{.*\}", out, re.DOTALL)
         if jmatch:
             raw = jmatch.group().replace("'", '"')
             cfg = json.loads(raw)
         else:
             cfg = None

         if not cfg or not isinstance(cfg, dict):
+            raise ValueError("Param output not parseable")
+        length = cfg.get("length","medium").lower()
+        tone = cfg.get("tone","neutral").lower()
+        mn = int(cfg.get("min_words") or cfg.get("min_length") or defaults[length][0])
+        mx = int(cfg.get("max_words") or cfg.get("max_length") or defaults[length][1])
         mn = max(5, min(mn, 2000))
         mx = max(mn + 5, min(mx, 4000))
         return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
     except Exception as e:
+        logger.exception("Param-generator parse failed: %s", e)
+        words = len(text.split())
+        length = "short" if words < 150 else ("medium" if words < 800 else "long")
+        mn, mx = defaults[length]
+        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

 # -------------------------
+# Threaded chunk summarization with per-chunk timeout (to prevent hang)
 # -------------------------
+executor = ThreadPoolExecutor(max_workers=min(8, max(2, (os.cpu_count() or 2))))
+CHUNK_TIMEOUT_SECONDS = 28
+REFINE_TIMEOUT_SECONDS = 60

 def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
     futures = {}
     results = [None] * len(chunks)
     for idx, chunk in enumerate(chunks):

@@ -310,26 +287,24 @@ def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> Lis
         futures[fut] = idx

     start = time.time()
+    for fut in as_completed(futures):
         idx = futures[fut]
         try:
             remaining = max(0.1, CHUNK_TIMEOUT_SECONDS - (time.time() - start))
             results[idx] = fut.result(timeout=remaining)
         except TimeoutError:
             logger.warning("Chunk %d timed out; using extractive fallback.", idx)
             results[idx] = extractive_prefilter(chunks[idx], top_k=3)
         except Exception as e:
+            logger.exception("Chunk %d failed: %s; falling back", idx, e)
             results[idx] = extractive_prefilter(chunks[idx], top_k=3)
     for i, r in enumerate(results):
         if not r:
             results[i] = extractive_prefilter(chunks[i], top_k=3)
     return results

 # -------------------------
+# Prompt helpers and refine
 # -------------------------
 def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
     tone = (tone or "neutral").lower()

@@ -337,13 +312,13 @@ def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int]
         instr = "Produce concise bullet points. Each bullet <= 20 words. No extra commentary."
     elif tone == "short":
         ts = target_sentences or 1
+        instr = f"Summarize in {ts} sentence{'s' if ts>1 else ''}. Be abstractive."
     elif tone == "formal":
+        instr = "Summarize in a formal, professional tone (2-4 sentences)."
     elif tone == "casual":
+        instr = "Summarize in a casual, conversational tone (1-3 sentences)."
     elif tone == "long":
+        instr = "Provide a structured summary (4-8 sentences)."
     else:
         instr = "Summarize in 2-3 clear sentences."
     instr += " Do not repeat information. Prefer rephrasing."

@@ -354,30 +329,51 @@ def refine_combined(pipe, summaries_list: List[str], tone: str, final_target_sen
     if len(combined.split()) > 1200:
         combined = extractive_prefilter(combined, top_k=20)
     prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
     fut = executor.submit(summarize_with_model, pipe, prompt, short_target=False)
     try:
         return fut.result(timeout=REFINE_TIMEOUT_SECONDS)
     except TimeoutError:
+        logger.warning("Refine timed out; returning concatenated chunk summaries.")
         return " ".join(summaries_list[:6])
     except Exception as e:
+        logger.exception("Refine failed: %s", e)
         return " ".join(summaries_list[:6])

 # -------------------------
+# Routes
 # -------------------------
 @app.route("/", methods=["GET"])
 def home():
     try:
         return render_template("index.html")
     except Exception:
+        return "Summarizer (lazy-load) — POST /summarize with JSON {text:'...'}", 200
+
+@app.route("/preload", methods=["POST"])
+def preload_models():
+    """
+    Explicit endpoint to attempt preloading heavy models.
+    Call this only when you want the process to attempt loading Pegasus/LED (may be slow).
+    """
+    results = {}
+    for key, model_name in [("pegasus", PEGASUS_MODEL), ("led", LED_MODEL), ("distilbart", DISTILBART_MODEL)]:
+        if key in _SUMMARIZER_CACHE:
+            results[key] = "already_loaded"
+            continue
+        try:
+            p = safe_load_pipeline(model_name)
+            if p:
+                _SUMMARIZER_CACHE[key] = p
+                results[key] = "loaded"
+            else:
+                results[key] = "failed"
+        except Exception as e:
+            results[key] = f"error: {e}"
+    return jsonify(results)

 @app.route("/summarize", methods=["POST"])
 def summarize_route():
+    t0 = time.time()
     data = request.get_json(force=True) or {}
     text = (data.get("text") or "").strip()[:90000]
     user_model_pref = (data.get("model") or "auto").lower()

@@ -387,16 +383,16 @@ def summarize_route():
     if not text or len(text.split()) < 5:
         return jsonify({"error": "Input too short."}), 400

+    # decide settings
+    if requested_length in ("auto","ai") or requested_tone in ("auto","ai"):
         cfg = generate_summarization_config(text)
+        length_choice = cfg.get("length","medium")
+        tone_choice = cfg.get("tone","neutral")
     else:
         length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
         tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"

+    # model selection logic
     words_len = len(text.split())
     prefer_led = False
     if user_model_pref == "led":

@@ -406,84 +402,65 @@ def summarize_route():
     else:
         if length_choice == "long" or words_len > 3000:
             prefer_led = True

     model_key = "led" if prefer_led else (_PREFERRED_SUMMARIZER_KEY or "distilbart")
+    try:
+        summarizer_pipe = get_summarizer(model_key)
+    except Exception as e:
+        logger.exception("get_summarizer failed (%s). Falling back to distilbart.", e)
+        summarizer_pipe = get_summarizer("distilbart")
         model_key = "distilbart"

+    # prefilter very long inputs for non-LED
     if model_key != "led" and words_len > 2500:
         text_for_chunks = extractive_prefilter(text, top_k=40)
     else:
         text_for_chunks = text

+    # chunk sizing
     if model_key == "led":
         chunk_max = 6000
         overlap = 400
     else:
+        chunk_max = 800
         overlap = 120

     chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max, overlap=overlap)
+    chunk_target = 1 if length_choice == "short" else 2

+    # summarize chunks in parallel
     try:
+        chunk_summaries = summarize_chunks_parallel(summarizer_pipe, chunks, chunk_target)
     except Exception as e:
         logger.exception("Chunk summarization orchestration failed: %s", e)
         chunk_summaries = [extractive_prefilter(c, top_k=3) for c in chunks]

+    # refine step — prefer Pegasus if loaded, otherwise use current pipe
     refine_pipe = _SUMMARIZER_CACHE.get("pegasus") or summarizer_pipe
+    final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice, 3)
+    final = refine_combined(refine_pipe, chunk_summaries, tone_choice, final_target_sentences)

+    # bullet postprocess
     if tone_choice == "bullet":
         parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
         bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
         final = "\n".join(bullets[:20])

+    elapsed = time.time() - t0
     meta = {
         "length_choice": length_choice,
         "tone": tone_choice,
         "model_used": model_key,
         "chunks": len(chunks),
         "input_words": words_len,
         "time_seconds": round(elapsed, 2),
+        "device": "cpu"
     }
     return jsonify({"summary": final, "meta": meta})

 # -------------------------
+# Local run (safe)
 # -------------------------
 if __name__ == "__main__":
+    # For local testing you may call preload_models_at_startup manually or use /preload.
     app.run(host="0.0.0.0", port=7860, debug=False)
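A minimal client sketch for exercising the updated routes, assuming the Space is reachable at http://localhost:7860 (see app.run at the end of main.py). The "length" and "tone" request keys are an assumption: the route reads requested_length and requested_tone, but the lines that parse them from the request fall outside the hunks shown above.

# client_sketch.py - hypothetical client, not part of the commit above.
import requests

BASE = "http://localhost:7860"  # adjust to the deployed Space URL

# Optional: ask the server to load Pegasus/LED/DistilBART up front (can be very slow on CPU).
print(requests.post(f"{BASE}/preload", timeout=600).json())

payload = {
    "text": "Paste a long article here. " * 200,  # must contain at least 5 words
    "model": "auto",      # 'auto' | 'pegasus' | 'led' | 'distilbart'
    "length": "auto",     # assumed key; 'auto'/'ai' lets the param-generator decide
    "tone": "neutral",    # assumed key; 'neutral' | 'formal' | 'casual' | 'bullet'
}
resp = requests.post(f"{BASE}/summarize", json=payload, timeout=300)
resp.raise_for_status()
body = resp.json()
print(body["meta"])       # model_used, chunks, time_seconds, device, ...
print(body["summary"])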