sayanAIAI committed 1045143 (verified) · Parent(s): 440165f

Update main.py

Files changed (1):
  1. main.py +184 -207
main.py CHANGED
@@ -1,4 +1,4 @@
- # main.py
  import os
  os.environ['HF_HOME'] = '/tmp'

@@ -22,27 +22,26 @@ logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger("summarizer")

  # -------------------------
- # Device selection (CPU-only explicitly per your request)
  # -------------------------
- # The user asked for CPU — force device -1. If you later enable GPU, set DEVICE accordingly.
  USE_GPU = False
  DEVICE = -1
- logger.info("Startup: forcing CPU usage for all models (DEVICE=%s)", DEVICE)

  # -------------------------
  # Model names and caches
  # -------------------------
  PEGASUS_MODEL = "google/pegasus-large"
  LED_MODEL = "allenai/led-large-16384"
- FALLBACK_MODEL = "sshleifer/distilbart-cnn-12-6"
- PARAM_MODEL = "google/flan-t5-small" # small instruction model

  _SUMMARIZER_CACHE: Dict[str, Any] = {}
  _PARAM_GENERATOR = None
  _PREFERRED_SUMMARIZER_KEY: Optional[str] = None

  # -------------------------
- # Utilities
  # -------------------------
  _STOPWORDS = {
      "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"
@@ -68,7 +67,6 @@ def extractive_prefilter(text: str, top_k: int = 6) -> str:
      return " ".join(chosen)

  def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
-     # small chunks to keep CPU generation per-call bounded
      n = len(text)
      if n <= max_chars:
          return [text]
@@ -89,97 +87,98 @@ def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
          start = end
      return [p for p in parts if p]

- def _first_int_from_text(s: str, fallback: Optional[int] = None) -> Optional[int]:
-     m = re.search(r"\d{1,4}", s)
-     return int(m.group()) if m else fallback
-
  # -------------------------
- # Safe model loading (preload on startup)
  # -------------------------
- def safe_load_param_generator():
-     global _PARAM_GENERATOR
-     try:
-         logger.info("Loading param-generator (text2text) model: %s", PARAM_MODEL)
-         p_tok = AutoTokenizer.from_pretrained(PARAM_MODEL)
-         p_mod = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
-         # IMPORTANT: use text2text-generation so outputs are in generated_text (not summary_text)
-         _PARAM_GENERATOR = pipeline("text2text-generation", model=p_mod, tokenizer=p_tok, device=DEVICE)
-         logger.info("Param-generator loaded as text2text-generation.")
-     except Exception as e:
-         logger.exception("Param-generator failed to load as text2text: %s", e)
-         _PARAM_GENERATOR = None
-
- def preload_models_at_startup():
-     global _PREFERRED_SUMMARIZER_KEY
-     # Preload Pegasus and LED (best-effort). If fail, load fallback.
-     logger.info("Preloading summarizer models (Pegasus, LED, fallback)...")
-     p = safe_load_pipeline(PEGASUS_MODEL)
-     if p:
-         _SUMMARIZER_CACHE["pegasus"] = p
-         _PREFERRED_SUMMARIZER_KEY = "pegasus"
-     else:
-         logger.warning("Pegasus failed to load on CPU; will still attempt LED and fallback.")
-
-     l = safe_load_pipeline(LED_MODEL)
-     if l:
-         _SUMMARIZER_CACHE["led"] = l
-
-     # Always ensure fallback available
-     f = safe_load_pipeline(FALLBACK_MODEL)
-     if f:
-         _SUMMARIZER_CACHE["distilbart"] = f
-         if not _PREFERRED_SUMMARIZER_KEY:
-             _PREFERRED_SUMMARIZER_KEY = "distilbart"
-     else:
-         logger.critical("Fallback model failed to load. The app will not be able to summarize.")
-
-     # Load param generator if possible (small model)
      try:
-         pg = safe_load_pipeline(PARAM_MODEL) # we reuse safe loader but it's summarization pipeline shape; OK for short generation
-         # prefer to keep the pipeline object (works as text2text too)
-         if pg:
-             _PARAM_GENERATOR = pg
-             logger.info("Param-generator loaded (via safe_load_pipeline).")
-         else:
-             _PARAM_GENERATOR = None
-     except Exception as e:
-         logger.exception("Param-generator load failed: %s", e)
-         _PARAM_GENERATOR = None
-
- preload_models_at_startup()
-
- # If _PARAM_GENERATOR exists but is a summarization pipeline, that's okay for our short prompt usage.

  # -------------------------
- # Helpers: generation strategy (fast-first, quality-fallback)
  # -------------------------
- def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) -> str:
      """
-     Fast-first then quality fallback. This function uses conservative settings to keep CPU latency bounded.
-     It is intended to be called inside a per-chunk ThreadPool with timeout control externally.
      """
      model_name = getattr(pipe.model.config, "name_or_path", "") or ""
-     model_name_lower = model_name.lower()
-     is_led = "led" in model_name_lower or "longformer" in model_name_lower

-     # FAST sampling pass (usually quicker than beams on CPU)
      fast_cfg = {
          "max_new_tokens": 64 if short_target else (120 if not is_led else 240),
          "do_sample": True,
-         "top_p": 0.90,
          "temperature": 0.85,
          "num_beams": 1,
          "early_stopping": True,
          "no_repeat_ngram_size": 3,
      }
-
      try:
-         out = pipe(text_prompt, **fast_cfg)[0]["summary_text"].strip()
-         return out
      except Exception as e:
-         logger.warning("Fast pass failed: %s; trying quality pass...", e)

-     # QUALITY beam pass (more expensive)
      quality_cfg = {
          "max_new_tokens": 140 if not is_led else 320,
          "do_sample": False,
@@ -187,121 +186,99 @@ def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) -> str:
          "early_stopping": True,
          "no_repeat_ngram_size": 3,
      }
-
      try:
-         out = pipe(text_prompt, **quality_cfg)[0]["summary_text"].strip()
-         return out
-     except Exception as e2:
-         logger.exception("Quality pass failed: %s", e2)

-     # extractive fallback
      try:
          return extractive_prefilter(text_prompt, top_k=3)
      except Exception:
-         return "Summarization failed; please try shorter input."

  # -------------------------
- # Param generator (AI decision) with fallback heuristic
  # -------------------------
  def generate_summarization_config(text: str) -> Dict[str, Any]:
-     """
-     Uses the text2text param-generator to output a JSON config.
-     If the generator fails or returns something noisy (e.g., echoes the input),
-     fall back to a safe heuristic.
-     """
      defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
-     # heuristic fallback
-     def fallback():
          words = len(text.split())
          length = "short" if words < 150 else ("medium" if words < 800 else "long")
          mn, mx = defaults[length]
          return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

-     pg = _PARAM_GENERATOR
-     if pg is None:
-         logger.info("Param-generator not available; using fallback heuristic.")
-         return fallback()
-
      prompt = (
-         "Recommend summarization settings for this text. Answer ONLY with JSON of the form:\n"
          '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
          "Text:\n'''"
          + text[:3000] + "'''"
      )
-
      try:
-         out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1, early_stopping=True)[0]
-         # different pipeline versions may return different keys; check both:
-         out = out_item.get("generated_text") or out_item.get("summary_text") or out_item.get("text") or ""
          out = (out or "").strip()
-
-         # COMMON FAILURE MODE: the model just echoes the input — reject that
-         # If output contains a long substring of the input, treat as invalid.
          if not out:
              raise ValueError("Empty param-generator output")
-         # If the returned text contains more than 40% of the original input words, treat it as echo
          input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
          out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
-         if len(input_words) > 0 and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
-             logger.warning("Param-generator appears to echo input; discarding and using heuristic.")
-             return fallback()

-         # Find JSON object in output
          jmatch = re.search(r"\{.*\}", out, re.DOTALL)
          if jmatch:
              raw = jmatch.group().replace("'", '"')
              cfg = json.loads(raw)
          else:
-             # attempt to parse line with key:value pairs (tolerant)
              cfg = None

          if not cfg or not isinstance(cfg, dict):
-             logger.warning("Param-generator output not parseable as JSON: %s", out[:300])
-             return fallback()
-
-         length = cfg.get("length", "medium").lower()
-         tone = cfg.get("tone", "neutral").lower()
-         min_w = cfg.get("min_words") or cfg.get("min_length") or cfg.get("min")
-         max_w = cfg.get("max_words") or cfg.get("max_length") or cfg.get("max")
-
-         if length not in ("short","medium","long"):
-             words = len(text.split())
-             length = "short" if words < 150 else ("medium" if words < 800 else "long")
-         if tone not in ("neutral","formal","casual","bullet"):
-             tone = "neutral"
-
-         defaults_min, defaults_max = defaults.get(length, (50,130))
-         try:
-             mn = int(min_w) if min_w is not None else defaults_min
-             mx = int(max_w) if max_w is not None else defaults_max
-         except Exception:
-             mn, mx = defaults_min, defaults_max
-
          mn = max(5, min(mn, 2000))
          mx = max(mn + 5, min(mx, 4000))
-         logger.info("Param-generator suggested length=%s tone=%s min=%s max=%s", length, tone, mn, mx)
          return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
-
      except Exception as e:
-         logger.exception("Param-generator failed to produce usable config: %s", e)
-         return fallback()

  # -------------------------
- # Orchestrator: chunk summarization with threadpool + timeouts
  # -------------------------
- # Tunable parameters
- MAX_WORKERS = min(8, max(2, (os.cpu_count() or 2))) # number of threads for parallel chunk work
- CHUNK_TIMEOUT_SECONDS = 28 # per-chunk timeout (safe lower than common gunicorn timeout)
- REFINE_TIMEOUT_SECONDS = 60 # final refinement timeout
- MAX_TOTAL_SECONDS = 180 # overall safety cap for a request
-
- executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

  def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
-     """
-     Submit chunk summarization tasks to threadpool; apply per-chunk timeout.
-     If a chunk times out or fails, use extractive_prefilter fallback.
-     """
      futures = {}
      results = [None] * len(chunks)
      for idx, chunk in enumerate(chunks):
@@ -310,26 +287,24 @@ def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
          futures[fut] = idx

      start = time.time()
-     for fut in as_completed(futures, timeout=None):
          idx = futures[fut]
          try:
              remaining = max(0.1, CHUNK_TIMEOUT_SECONDS - (time.time() - start))
-             # bound waiting for each future to CHUNK_TIMEOUT_SECONDS
              results[idx] = fut.result(timeout=remaining)
          except TimeoutError:
              logger.warning("Chunk %d timed out; using extractive fallback.", idx)
              results[idx] = extractive_prefilter(chunks[idx], top_k=3)
          except Exception as e:
-             logger.exception("Chunk %d failed: %s; using extractive fallback.", idx, e)
              results[idx] = extractive_prefilter(chunks[idx], top_k=3)
-     # ensure no None
      for i, r in enumerate(results):
          if not r:
              results[i] = extractive_prefilter(chunks[i], top_k=3)
      return results

  # -------------------------
- # Prompt builder and refine
  # -------------------------
  def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
      tone = (tone or "neutral").lower()
@@ -337,13 +312,13 @@ def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
          instr = "Produce concise bullet points. Each bullet <= 20 words. No extra commentary."
      elif tone == "short":
          ts = target_sentences or 1
-         instr = f"Summarize in {ts} sentence{'s' if ts>1 else ''}. Be abstractive and avoid copying."
      elif tone == "formal":
-         instr = "Summarize in a formal professional tone in 2-4 sentences."
      elif tone == "casual":
-         instr = "Summarize in a casual tone in 1-3 sentences."
      elif tone == "long":
-         instr = "Provide 4-8 sentence structured summary covering key points."
      else:
          instr = "Summarize in 2-3 clear sentences."
      instr += " Do not repeat information. Prefer rephrasing."
@@ -354,30 +329,51 @@ def refine_combined(pipe, summaries_list: List[str], tone: str, final_target_sentences
      if len(combined.split()) > 1200:
          combined = extractive_prefilter(combined, top_k=20)
      prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
-     # run refine in executor to apply timeout
      fut = executor.submit(summarize_with_model, pipe, prompt, short_target=False)
      try:
          return fut.result(timeout=REFINE_TIMEOUT_SECONDS)
      except TimeoutError:
-         logger.warning("Refine pass timed out; returning concatenated chunk summaries as fallback.")
          return " ".join(summaries_list[:6])
      except Exception as e:
-         logger.exception("Refine pass failed: %s", e)
          return " ".join(summaries_list[:6])

  # -------------------------
- # Endpoint
  # -------------------------
  @app.route("/", methods=["GET"])
  def home():
      try:
          return render_template("index.html")
      except Exception:
-         return "Summarizer (CPU-ready) — POST /summarize with JSON {text: '...'}", 200

  @app.route("/summarize", methods=["POST"])
  def summarize_route():
-     t_start = time.time()
      data = request.get_json(force=True) or {}
      text = (data.get("text") or "").strip()[:90000]
      user_model_pref = (data.get("model") or "auto").lower()
@@ -387,16 +383,16 @@ def summarize_route():
      if not text or len(text.split()) < 5:
          return jsonify({"error": "Input too short."}), 400

-     # Decide settings
-     if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
          cfg = generate_summarization_config(text)
-         length_choice = cfg.get("length", "medium")
-         tone_choice = cfg.get("tone", "neutral")
      else:
          length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
          tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"

-     # Model selection (user or auto)
      words_len = len(text.split())
      prefer_led = False
      if user_model_pref == "led":
@@ -406,84 +402,65 @@
      else:
          if length_choice == "long" or words_len > 3000:
              prefer_led = True
-         else:
-             prefer_led = False

-     # Choose actual pipe
      model_key = "led" if prefer_led else (_PREFERRED_SUMMARIZER_KEY or "distilbart")
-     if model_key not in _SUMMARIZER_CACHE:
-         # try to load it (may fail on CPU but we attempt)
-         try:
-             pipe = safe_load_pipeline(LED_MODEL if model_key=="led" else PEGASUS_MODEL if model_key=="pegasus" else FALLBACK_MODEL)
-             if pipe:
-                 _SUMMARIZER_CACHE[model_key] = pipe
-         except Exception as e:
-             logger.exception("On-demand load failed: %s", e)
-
-     # Ensure we have at least a fallback
-     if model_key not in _SUMMARIZER_CACHE:
          model_key = "distilbart"
-     summarizer_pipe = _SUMMARIZER_CACHE[model_key]

-     # Prepare text for chunks
      if model_key != "led" and words_len > 2500:
          text_for_chunks = extractive_prefilter(text, top_k=40)
      else:
          text_for_chunks = text

-     # Chunk sizing
      if model_key == "led":
          chunk_max = 6000
          overlap = 400
      else:
-         chunk_max = 800 # small chunks help CPU reliability
          overlap = 120

      chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max, overlap=overlap)

-     # Summarize chunks in parallel with per-chunk timeouts
-     chunk_target_sentences = 1 if length_choice == "short" else 2
-     t_chunks_start = time.time()
      try:
-         chunk_summaries = summarize_chunks_parallel(summarizer_pipe, chunks, chunk_target_sentences)
      except Exception as e:
          logger.exception("Chunk summarization orchestration failed: %s", e)
-         # fallback: simple extractive split
          chunk_summaries = [extractive_prefilter(c, top_k=3) for c in chunks]
-     t_chunks_end = time.time()

-     # Final refine: prefer Pegasus if available in cache (for nicer prose), else current pipe
      refine_pipe = _SUMMARIZER_CACHE.get("pegasus") or summarizer_pipe
-     final_target = {"short":1,"medium":3,"long":6}.get(length_choice, 3)
-     final = refine_combined(refine_pipe, chunk_summaries, tone_choice, final_target)

-     # Bullet postprocess
      if tone_choice == "bullet":
          parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
          bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
          final = "\n".join(bullets[:20])

-     elapsed = time.time() - t_start
      meta = {
          "length_choice": length_choice,
          "tone": tone_choice,
-         "model_requested": user_model_pref,
          "model_used": model_key,
          "chunks": len(chunks),
          "input_words": words_len,
          "time_seconds": round(elapsed, 2),
-         "time_chunk_phase": round(t_chunks_end - t_chunks_start, 2),
-         "device": "cpu",
-         "workers_threads": MAX_WORKERS,
-         "per_chunk_timeout": CHUNK_TIMEOUT_SECONDS,
-         "refine_timeout": REFINE_TIMEOUT_SECONDS
      }
-
      return jsonify({"summary": final, "meta": meta})

  # -------------------------
- # Run for local testing
  # -------------------------
  if __name__ == "__main__":
-     # For production use Gunicorn with an increased timeout (e.g., gunicorn_conf.py with timeout=180).
      app.run(host="0.0.0.0", port=7860, debug=False)

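The comment on the removed __main__ block above points at running the app under Gunicorn with a longer worker timeout (a gunicorn_conf.py with timeout=180). That file is not part of this commit; a minimal sketch of what it could look like, assuming the Flask app object stays named app inside main.py:

# gunicorn_conf.py (hypothetical sketch, not included in the repository)
bind = "0.0.0.0:7860"   # same host/port the Flask dev server uses above
workers = 1             # the summarization pipelines are memory-heavy; keep a single copy
threads = 2             # a little request concurrency inside the one worker
timeout = 180           # matches the removed MAX_TOTAL_SECONDS request cap

# Launch with: gunicorn -c gunicorn_conf.py main:app
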
+ # main.py (REPLACE your existing file with this)
  import os
  os.environ['HF_HOME'] = '/tmp'

  logger = logging.getLogger("summarizer")

  # -------------------------
+ # Device selection (CPU by default)
  # -------------------------
  USE_GPU = False
  DEVICE = -1
+ logger.info("Startup: forcing CPU usage for models (DEVICE=%s)", DEVICE)

  # -------------------------
  # Model names and caches
  # -------------------------
  PEGASUS_MODEL = "google/pegasus-large"
  LED_MODEL = "allenai/led-large-16384"
+ DISTILBART_MODEL = "sshleifer/distilbart-cnn-12-6"
+ PARAM_MODEL = "google/flan-t5-small"

  _SUMMARIZER_CACHE: Dict[str, Any] = {}
  _PARAM_GENERATOR = None
  _PREFERRED_SUMMARIZER_KEY: Optional[str] = None

  # -------------------------
+ # Utilities: chunking, extractive fallback
  # -------------------------
  _STOPWORDS = {
      "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"

      return " ".join(chosen)

  def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
      n = len(text)
      if n <= max_chars:
          return [text]

          start = end
      return [p for p in parts if p]

  # -------------------------
+ # safe loader (defined before any calls)
  # -------------------------
+ def safe_load_pipeline(model_name: str):
+     """
+     Try to load a summarization pipeline robustly:
+     - try fast tokenizer first
+     - if that fails, try use_fast=False
+     - return pipeline or None if both fail
+     """
      try:
+         logger.info("Loading tokenizer/model for %s (fast)...", model_name)
+         tok = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+         pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
+         logger.info("Loaded %s (fast tokenizer)", model_name)
+         return pipe
+     except Exception as e_fast:
+         logger.warning("Fast tokenizer failed for %s: %s. Trying slow tokenizer...", model_name, e_fast)
+         try:
+             tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+             model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+             pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
+             logger.info("Loaded %s (slow tokenizer)", model_name)
+             return pipe
+         except Exception as e_slow:
+             logger.exception("Slow tokenizer failed for %s: %s", model_name, e_slow)
+             return None

  # -------------------------
+ # get_summarizer: lazy load + cache + fallback
  # -------------------------
+ def get_summarizer(key: str):
      """
+     key: 'pegasus'|'led'|'distilbart'|'auto'
+     returns a pipeline (cached), or raises RuntimeError if no pipeline can be loaded.
      """
+     key = (key or "auto").lower()
+     if key == "auto":
+         key = _PREFERRED_SUMMARIZER_KEY or "distilbart"
+
+     # direct mapping
+     model_name = {
+         "pegasus": PEGASUS_MODEL,
+         "led": LED_MODEL,
+         "distilbart": DISTILBART_MODEL
+     }.get(key, DISTILBART_MODEL)
+
+     if key in _SUMMARIZER_CACHE:
+         return _SUMMARIZER_CACHE[key]
+
+     # try to load
+     logger.info("Attempting to lazy-load summarizer '%s' -> %s", key, model_name)
+     pipe = safe_load_pipeline(model_name)
+     if pipe:
+         _SUMMARIZER_CACHE[key] = pipe
+         return pipe
+
+     # fallback attempts
+     logger.warning("Failed to load %s. Trying distilbart fallback.", key)
+     if "distilbart" in _SUMMARIZER_CACHE:
+         return _SUMMARIZER_CACHE["distilbart"]
+     fb = safe_load_pipeline(DISTILBART_MODEL)
+     if fb:
+         _SUMMARIZER_CACHE["distilbart"] = fb
+         return fb
+
+     # nothing works
+     raise RuntimeError("No summarizer available. Install required libraries and/or choose smaller model.")
+
+ # -------------------------
+ # Generation strategy + small helpers
+ # -------------------------
+ def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) -> str:
      model_name = getattr(pipe.model.config, "name_or_path", "") or ""
+     is_led = "led" in model_name.lower() or "longformer" in model_name.lower()

+     # Fast sampling pass
      fast_cfg = {
          "max_new_tokens": 64 if short_target else (120 if not is_led else 240),
          "do_sample": True,
+         "top_p": 0.92,
          "temperature": 0.85,
          "num_beams": 1,
          "early_stopping": True,
          "no_repeat_ngram_size": 3,
      }
      try:
+         return pipe(text_prompt, **fast_cfg)[0].get("summary_text","").strip()
      except Exception as e:
+         logger.warning("Fast pass failed: %s, trying quality pass...", e)

      quality_cfg = {
          "max_new_tokens": 140 if not is_led else 320,
          "do_sample": False,

          "early_stopping": True,
          "no_repeat_ngram_size": 3,
      }
      try:
+         return pipe(text_prompt, **quality_cfg)[0].get("summary_text","").strip()
+     except Exception as e:
+         logger.exception("Quality pass failed: %s", e)

+     # fallback extractive
      try:
          return extractive_prefilter(text_prompt, top_k=3)
      except Exception:
+         return "Summarization failed; try shorter input."

  # -------------------------
+ # Param generator (AI decision) - lazy loader
  # -------------------------
+ def get_param_generator():
+     global _PARAM_GENERATOR
+     if _PARAM_GENERATOR is not None:
+         return _PARAM_GENERATOR
+     # try to load text2text pipeline for PARAM_MODEL
+     try:
+         logger.info("Loading param-generator (text2text) lazily: %s", PARAM_MODEL)
+         tok = AutoTokenizer.from_pretrained(PARAM_MODEL)
+         mod = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
+         _PARAM_GENERATOR = pipeline("text2text-generation", model=mod, tokenizer=tok, device=DEVICE)
+         logger.info("Param-generator loaded.")
+         return _PARAM_GENERATOR
+     except Exception as e:
+         logger.exception("Param-generator lazy load failed: %s", e)
+         _PARAM_GENERATOR = None
+         return None
+
  def generate_summarization_config(text: str) -> Dict[str, Any]:
      defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
+     pg = get_param_generator()
+     if pg is None:
          words = len(text.split())
          length = "short" if words < 150 else ("medium" if words < 800 else "long")
          mn, mx = defaults[length]
          return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

      prompt = (
+         "Recommend summarization settings for this text. Answer ONLY with JSON like:\n"
          '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
          "Text:\n'''"
          + text[:3000] + "'''"
      )
      try:
+         out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1)[0]
+         out = out_item.get("generated_text") or out_item.get("summary_text") or ""
          out = (out or "").strip()
          if not out:
              raise ValueError("Empty param-generator output")
+         # reject noisy echo outputs
          input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
          out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
+         if len(input_words) and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
+             logger.warning("Param-generator appears to echo input; using heuristic fallback.")
+             words = len(text.split())
+             length = "short" if words < 150 else ("medium" if words < 800 else "long")
+             mn, mx = defaults[length]
+             return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

          jmatch = re.search(r"\{.*\}", out, re.DOTALL)
          if jmatch:
              raw = jmatch.group().replace("'", '"')
              cfg = json.loads(raw)
          else:
              cfg = None

          if not cfg or not isinstance(cfg, dict):
+             raise ValueError("Param output not parseable")
+         length = cfg.get("length","medium").lower()
+         tone = cfg.get("tone","neutral").lower()
+         mn = int(cfg.get("min_words") or cfg.get("min_length") or defaults[length][0])
+         mx = int(cfg.get("max_words") or cfg.get("max_length") or defaults[length][1])
          mn = max(5, min(mn, 2000))
          mx = max(mn + 5, min(mx, 4000))
          return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
      except Exception as e:
+         logger.exception("Param-generator parse failed: %s", e)
+         words = len(text.split())
+         length = "short" if words < 150 else ("medium" if words < 800 else "long")
+         mn, mx = defaults[length]
+         return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

  # -------------------------
+ # Threaded chunk summarization with per-chunk timeout (to prevent hang)
  # -------------------------
+ executor = ThreadPoolExecutor(max_workers=min(8, max(2, (os.cpu_count() or 2))))
+ CHUNK_TIMEOUT_SECONDS = 28
+ REFINE_TIMEOUT_SECONDS = 60

  def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
      futures = {}
      results = [None] * len(chunks)
      for idx, chunk in enumerate(chunks):

          futures[fut] = idx

      start = time.time()
+     for fut in as_completed(futures):
          idx = futures[fut]
          try:
              remaining = max(0.1, CHUNK_TIMEOUT_SECONDS - (time.time() - start))
              results[idx] = fut.result(timeout=remaining)
          except TimeoutError:
              logger.warning("Chunk %d timed out; using extractive fallback.", idx)
              results[idx] = extractive_prefilter(chunks[idx], top_k=3)
          except Exception as e:
+             logger.exception("Chunk %d failed: %s; falling back", idx, e)
              results[idx] = extractive_prefilter(chunks[idx], top_k=3)
      for i, r in enumerate(results):
          if not r:
              results[i] = extractive_prefilter(chunks[i], top_k=3)
      return results

  # -------------------------
+ # Prompt helpers and refine
  # -------------------------
  def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
      tone = (tone or "neutral").lower()

          instr = "Produce concise bullet points. Each bullet <= 20 words. No extra commentary."
      elif tone == "short":
          ts = target_sentences or 1
+         instr = f"Summarize in {ts} sentence{'s' if ts>1 else ''}. Be abstractive."
      elif tone == "formal":
+         instr = "Summarize in a formal, professional tone (2-4 sentences)."
      elif tone == "casual":
+         instr = "Summarize in a casual, conversational tone (1-3 sentences)."
      elif tone == "long":
+         instr = "Provide a structured summary (4-8 sentences)."
      else:
          instr = "Summarize in 2-3 clear sentences."
      instr += " Do not repeat information. Prefer rephrasing."

      if len(combined.split()) > 1200:
          combined = extractive_prefilter(combined, top_k=20)
      prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
      fut = executor.submit(summarize_with_model, pipe, prompt, short_target=False)
      try:
          return fut.result(timeout=REFINE_TIMEOUT_SECONDS)
      except TimeoutError:
+         logger.warning("Refine timed out; returning concatenated chunk summaries.")
          return " ".join(summaries_list[:6])
      except Exception as e:
+         logger.exception("Refine failed: %s", e)
          return " ".join(summaries_list[:6])

  # -------------------------
+ # Routes
  # -------------------------
  @app.route("/", methods=["GET"])
  def home():
      try:
          return render_template("index.html")
      except Exception:
+         return "Summarizer (lazy-load) — POST /summarize with JSON {text:'...'}", 200
+
+ @app.route("/preload", methods=["POST"])
+ def preload_models():
+     """
+     Explicit endpoint to attempt preloading heavy models.
+     Call this only when you want the process to attempt loading Pegasus/LED (may be slow).
+     """
+     results = {}
+     for key, model_name in [("pegasus", PEGASUS_MODEL), ("led", LED_MODEL), ("distilbart", DISTILBART_MODEL)]:
+         if key in _SUMMARIZER_CACHE:
+             results[key] = "already_loaded"
+             continue
+         try:
+             p = safe_load_pipeline(model_name)
+             if p:
+                 _SUMMARIZER_CACHE[key] = p
+                 results[key] = "loaded"
+             else:
+                 results[key] = "failed"
+         except Exception as e:
+             results[key] = f"error: {e}"
+     return jsonify(results)

  @app.route("/summarize", methods=["POST"])
  def summarize_route():
+     t0 = time.time()
      data = request.get_json(force=True) or {}
      text = (data.get("text") or "").strip()[:90000]
      user_model_pref = (data.get("model") or "auto").lower()

      if not text or len(text.split()) < 5:
          return jsonify({"error": "Input too short."}), 400

+     # decide settings
+     if requested_length in ("auto","ai") or requested_tone in ("auto","ai"):
          cfg = generate_summarization_config(text)
+         length_choice = cfg.get("length","medium")
+         tone_choice = cfg.get("tone","neutral")
      else:
          length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
          tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"

+     # model selection logic
      words_len = len(text.split())
      prefer_led = False
      if user_model_pref == "led":

      else:
          if length_choice == "long" or words_len > 3000:
              prefer_led = True

      model_key = "led" if prefer_led else (_PREFERRED_SUMMARIZER_KEY or "distilbart")
+     try:
+         summarizer_pipe = get_summarizer(model_key)
+     except Exception as e:
+         logger.exception("get_summarizer failed (%s). Falling back to distilbart.", e)
+         summarizer_pipe = get_summarizer("distilbart")
          model_key = "distilbart"

+     # prefilter very long inputs for non-LED
      if model_key != "led" and words_len > 2500:
          text_for_chunks = extractive_prefilter(text, top_k=40)
      else:
          text_for_chunks = text

+     # chunk sizing
      if model_key == "led":
          chunk_max = 6000
          overlap = 400
      else:
+         chunk_max = 800
          overlap = 120

      chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max, overlap=overlap)
+     chunk_target = 1 if length_choice == "short" else 2

+     # summarize chunks in parallel
      try:
+         chunk_summaries = summarize_chunks_parallel(summarizer_pipe, chunks, chunk_target)
      except Exception as e:
          logger.exception("Chunk summarization orchestration failed: %s", e)
          chunk_summaries = [extractive_prefilter(c, top_k=3) for c in chunks]

+     # refine step — prefer Pegasus if loaded, otherwise use current pipe
      refine_pipe = _SUMMARIZER_CACHE.get("pegasus") or summarizer_pipe
+     final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice, 3)
+     final = refine_combined(refine_pipe, chunk_summaries, tone_choice, final_target_sentences)

+     # bullet postprocess
      if tone_choice == "bullet":
          parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
          bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
          final = "\n".join(bullets[:20])

+     elapsed = time.time() - t0
      meta = {
          "length_choice": length_choice,
          "tone": tone_choice,
          "model_used": model_key,
          "chunks": len(chunks),
          "input_words": words_len,
          "time_seconds": round(elapsed, 2),
+         "device": "cpu"
      }
      return jsonify({"summary": final, "meta": meta})

  # -------------------------
+ # Local run (safe)
  # -------------------------
  if __name__ == "__main__":
+     # For local testing you may call preload_models_at_startup manually or use /preload.
      app.run(host="0.0.0.0", port=7860, debug=False)
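
For quick testing of the new routes, a small client sketch follows. It is not part of the commit; it assumes the service is reachable on localhost:7860 and that the request keys for length and tone are "length" and "tone" (those reads sit in a part of the diff that is not shown), so adjust as needed.

import requests  # third-party HTTP client, assumed to be installed

BASE = "http://localhost:7860"

# Optional warm-up: ask the server to preload the heavy models (can be slow on CPU).
print(requests.post(f"{BASE}/preload").json())

payload = {
    "text": "Long article text goes here ...",
    "model": "auto",    # "auto" or "led", per the model selection logic above
    "length": "auto",   # assumed key name; "short" | "medium" | "long" | "auto"
    "tone": "bullet",   # assumed key name; "neutral" | "formal" | "casual" | "bullet" | "auto"
}
resp = requests.post(f"{BASE}/summarize", json=payload, timeout=180)
data = resp.json()
print(data["summary"])
print(data["meta"])     # model_used, chunks, input_words, time_seconds, device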