sayanAIAI committed on
Commit 440165f · verified · 1 Parent(s): e3483d8

Update main.py

Files changed (1):
  1. main.py +77 -46
main.py CHANGED
@@ -96,30 +96,18 @@ def _first_int_from_text(s: str, fallback: Optional[int] = None) -> Optional[int]:
 # -------------------------
 # Safe model loading (preload on startup)
 # -------------------------
-def safe_load_pipeline(model_name: str):
-    """
-    Try to load tokenizer & model in robust manner:
-    - try fast tokenizer
-    - fallback to use_fast=False
-    - if still fails, return None (caller should fallback to fallback model)
-    """
+def safe_load_param_generator():
+    global _PARAM_GENERATOR
     try:
-        tok = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-        pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
-        logger.info("Loaded pipeline for %s (fast tokenizer)", model_name)
-        return pipe
-    except Exception as e_fast:
-        logger.warning("Fast tokenizer load failed for %s: %s. Trying slow tokenizer...", model_name, e_fast)
-        try:
-            tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
-            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-            pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
-            logger.info("Loaded pipeline for %s (slow tokenizer)", model_name)
-            return pipe
-        except Exception as e_slow:
-            logger.exception("Slow tokenizer load failed for %s: %s", model_name, e_slow)
-            return None
+        logger.info("Loading param-generator (text2text) model: %s", PARAM_MODEL)
+        p_tok = AutoTokenizer.from_pretrained(PARAM_MODEL)
+        p_mod = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
+        # IMPORTANT: use text2text-generation so outputs are in generated_text (not summary_text)
+        _PARAM_GENERATOR = pipeline("text2text-generation", model=p_mod, tokenizer=p_tok, device=DEVICE)
+        logger.info("Param-generator loaded as text2text-generation.")
+    except Exception as e:
+        logger.exception("Param-generator failed to load as text2text: %s", e)
+        _PARAM_GENERATOR = None
 
 def preload_models_at_startup():
     global _PREFERRED_SUMMARIZER_KEY
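Note on the task choice in this hunk: with Hugging Face transformers, a "summarization" pipeline returns its output under the summary_text key, while a "text2text-generation" pipeline returns generated_text. A minimal sketch of the difference, using small public checkpoints as illustrative stand-ins (the repo's actual PARAM_MODEL may differ):

from transformers import pipeline

# text2text-generation puts its output under "generated_text"
t2t = pipeline("text2text-generation", model="google/flan-t5-small")
print(t2t("Answer yes or no: is the sky blue?", max_new_tokens=4)[0]["generated_text"])

# summarization puts its output under "summary_text"
summ = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6")
print(summ("Some long article text. " * 40, min_length=5, max_length=20)[0]["summary_text"])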
@@ -216,44 +204,87 @@ def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) ->
 # Param generator (AI decision) with fallback heuristic
 # -------------------------
 def generate_summarization_config(text: str) -> Dict[str, Any]:
-    # If parameter generator pipeline loaded, use it; else fallback to heuristic
-    pg_pipe = _PARAM_GENERATOR
-    if pg_pipe is None:
+    """
+    Uses the text2text param-generator to output a JSON config.
+    If the generator fails or returns something noisy (e.g., echoes the input),
+    fall back to a safe heuristic.
+    """
+    defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
+    # heuristic fallback
+    def fallback():
         words = len(text.split())
         length = "short" if words < 150 else ("medium" if words < 800 else "long")
-        defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
         mn, mx = defaults[length]
         return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
 
+    pg = _PARAM_GENERATOR
+    if pg is None:
+        logger.info("Param-generator not available; using fallback heuristic.")
+        return fallback()
+
     prompt = (
-        "Recommend summarization settings. Output JSON exactly like:\n"
+        "Recommend summarization settings for this text. Answer ONLY with JSON of the form:\n"
         '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
         "Text:\n'''"
         + text[:3000] + "'''"
     )
+
     try:
-        out = pg_pipe(prompt, max_new_tokens=48, do_sample=False, num_beams=1)[0].get("summary_text","")
-        # some pipelines return 'generated_text' or 'summary_text' depending; try both
+        out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1, early_stopping=True)[0]
+        # different pipeline versions may return different keys; check both:
+        out = out_item.get("generated_text") or out_item.get("summary_text") or out_item.get("text") or ""
+        out = (out or "").strip()
+
+        # COMMON FAILURE MODE: the model just echoes the input; reject that.
+        # If output contains a long substring of the input, treat as invalid.
         if not out:
-            out = pg_pipe(prompt, max_new_tokens=48, do_sample=False, num_beams=1)[0].get("generated_text","")
-        # attempt to extract JSON
-        j = re.search(r"\{.*\}", out, re.DOTALL)
-        if j:
-            cfg = json.loads(j.group().replace("'", '"'))
+            raise ValueError("Empty param-generator output")
+        # If the returned text contains more than 40% of the original input words, treat it as an echo
+        input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
+        out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
+        if len(input_words) > 0 and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
+            logger.warning("Param-generator appears to echo input; discarding and using heuristic.")
+            return fallback()
+
+        # Find JSON object in output
+        jmatch = re.search(r"\{.*\}", out, re.DOTALL)
+        if jmatch:
+            raw = jmatch.group().replace("'", '"')
+            cfg = json.loads(raw)
         else:
+            # attempt to parse a line of key:value pairs (tolerant)
             cfg = None
-        if not cfg:
-            raise ValueError("Unparseable param output")
-        length = cfg.get("length","medium").lower()
-        tone = cfg.get("tone","neutral").lower()
-        return {"length": length, "min_length": cfg.get("min_words", 50), "max_length": cfg.get("max_words", 130), "tone": tone}
+
+        if not cfg or not isinstance(cfg, dict):
+            logger.warning("Param-generator output not parseable as JSON: %s", out[:300])
+            return fallback()
+
+        length = cfg.get("length", "medium").lower()
+        tone = cfg.get("tone", "neutral").lower()
+        min_w = cfg.get("min_words") or cfg.get("min_length") or cfg.get("min")
+        max_w = cfg.get("max_words") or cfg.get("max_length") or cfg.get("max")
+
+        if length not in ("short", "medium", "long"):
+            words = len(text.split())
+            length = "short" if words < 150 else ("medium" if words < 800 else "long")
+        if tone not in ("neutral", "formal", "casual", "bullet"):
+            tone = "neutral"
+
+        defaults_min, defaults_max = defaults.get(length, (50, 130))
+        try:
+            mn = int(min_w) if min_w is not None else defaults_min
+            mx = int(max_w) if max_w is not None else defaults_max
+        except Exception:
+            mn, mx = defaults_min, defaults_max
+
+        mn = max(5, min(mn, 2000))
+        mx = max(mn + 5, min(mx, 4000))
+        logger.info("Param-generator suggested length=%s tone=%s min=%s max=%s", length, tone, mn, mx)
+        return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
+
     except Exception as e:
-        logger.exception("Param generator failed: %s; using heuristic", e)
-        words = len(text.split())
-        length = "short" if words < 150 else ("medium" if words < 800 else "long")
-        defaults = {"short": (12,50), "medium": (50,130), "long":(130,300)}
-        mn, mx = defaults[length]
-        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
 
 # -------------------------
 # Orchestrator: chunk summarization with threadpool + timeouts
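The echo check added inside generate_summarization_config can be read in isolation. A standalone sketch of the same overlap test (looks_like_echo is a hypothetical helper name, not part of main.py):

import re

def looks_like_echo(source: str, generated: str, threshold: float = 0.4) -> bool:
    # Compare the first 200 word tokens of each side as lowercase sets and
    # flag an echo when more than `threshold` of the input's distinct words
    # reappear in the generated text.
    src = set(w.lower() for w in re.findall(r"\w+", source)[:200])
    gen = set(w.lower() for w in re.findall(r"\w+", generated)[:200])
    return bool(src) and (len(src & gen) / max(1, len(src))) > threshold

assert looks_like_echo("the quick brown fox jumps", "The quick brown fox")
assert not looks_like_echo("the quick brown fox jumps", '{"length":"short","tone":"neutral"}')

A set-overlap ratio is deliberately cheap; it cannot tell paraphrase from echo, but for rejecting a generator that merely copied its input it is usually enough.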
 
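End to end, the generated config is meant to parameterize the actual summarization call. A hedged usage sketch, assuming a summarizer pipeline was loaded at startup (summarizer and article are illustrative names, not identifiers from main.py):

cfg = generate_summarization_config(article)
result = summarizer(
    article,
    min_length=cfg["min_length"],
    max_length=cfg["max_length"],
    do_sample=False,
)[0]["summary_text"]

One caveat worth knowing: the config's bounds are derived from min_words/max_words, while a transformers pipeline interprets min_length/max_length in tokens, so the two only roughly correspond.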