Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -96,30 +96,18 @@ def _first_int_from_text(s: str, fallback: Optional[int] = None) -> Optional[int
|
|
| 96 |
# -------------------------
|
| 97 |
# Safe model loading (preload on startup)
|
| 98 |
# -------------------------
|
| 99 |
-
def
|
| 100 |
-
|
| 101 |
-
Try to load tokenizer & model in robust manner:
|
| 102 |
-
- try fast tokenizer
|
| 103 |
-
- fallback to use_fast=False
|
| 104 |
-
- if still fails, return None (caller should fallback to fallback model)
|
| 105 |
-
"""
|
| 106 |
try:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 117 |
-
pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
|
| 118 |
-
logger.info("Loaded pipeline for %s (slow tokenizer)", model_name)
|
| 119 |
-
return pipe
|
| 120 |
-
except Exception as e_slow:
|
| 121 |
-
logger.exception("Slow tokenizer load failed for %s: %s", model_name, e_slow)
|
| 122 |
-
return None
|
| 123 |
|
| 124 |
def preload_models_at_startup():
|
| 125 |
global _PREFERRED_SUMMARIZER_KEY
|
|
@@ -216,44 +204,87 @@ def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) ->
|
|
| 216 |
# Param generator (AI decision) with fallback heuristic
|
| 217 |
# -------------------------
|
| 218 |
def generate_summarization_config(text: str) -> Dict[str, Any]:
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
words = len(text.split())
|
| 223 |
length = "short" if words < 150 else ("medium" if words < 800 else "long")
|
| 224 |
-
defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
|
| 225 |
mn, mx = defaults[length]
|
| 226 |
return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
prompt = (
|
| 229 |
-
"Recommend summarization settings.
|
| 230 |
'{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
|
| 231 |
"Text:\n'''"
|
| 232 |
+ text[:3000] + "'''"
|
| 233 |
)
|
|
|
|
| 234 |
try:
|
| 235 |
-
|
| 236 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
if not out:
|
| 238 |
-
|
| 239 |
-
#
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
else:
|
|
|
|
| 244 |
cfg = None
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
except Exception as e:
|
| 251 |
-
logger.exception("Param
|
| 252 |
-
|
| 253 |
-
length = "short" if words < 150 else ("medium" if words < 800 else "long")
|
| 254 |
-
defaults = {"short": (12,50), "medium": (50,130), "long":(130,300)}
|
| 255 |
-
mn, mx = defaults[length]
|
| 256 |
-
return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
|
| 257 |
|
| 258 |
# -------------------------
|
| 259 |
# Orchestrator: chunk summarization with threadpool + timeouts
|
|
|
|
| 96 |
# -------------------------
|
| 97 |
# Safe model loading (preload on startup)
|
| 98 |
# -------------------------
|
| 99 |
+
def safe_load_param_generator():
    """Populate the module-global ``_PARAM_GENERATOR`` with a text2text pipeline.

    Loads tokenizer and model for ``PARAM_MODEL`` and wraps them in a
    ``text2text-generation`` pipeline on ``DEVICE``.  Any failure is logged
    and leaves ``_PARAM_GENERATOR`` set to None so callers fall back to the
    heuristic config generator.
    """
    global _PARAM_GENERATOR
    try:
        logger.info("Loading param-generator (text2text) model: %s", PARAM_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
        # IMPORTANT: use text2text-generation so outputs are in generated_text (not summary_text)
        _PARAM_GENERATOR = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            device=DEVICE,
        )
        logger.info("Param-generator loaded as text2text-generation.")
    except Exception as e:
        logger.exception("Param-generator failed to load as text2text: %s", e)
        _PARAM_GENERATOR = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
def preload_models_at_startup():
|
| 113 |
global _PREFERRED_SUMMARIZER_KEY
|
|
|
|
| 204 |
# Param generator (AI decision) with fallback heuristic
|
| 205 |
# -------------------------
|
| 206 |
def generate_summarization_config(text: str) -> Dict[str, Any]:
    """
    Uses the text2text param-generator to output a JSON config.
    If the generator fails or returns something noisy (e.g., echoes the input),
    fall back to a safe heuristic.

    Returns a dict with keys ``length`` ("short"/"medium"/"long"),
    ``min_length``, ``max_length`` (ints, clamped), and ``tone``
    ("neutral"/"formal"/"casual"/"bullet").
    """
    defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}

    # heuristic fallback based purely on input word count
    def fallback():
        words = len(text.split())
        length = "short" if words < 150 else ("medium" if words < 800 else "long")
        mn, mx = defaults[length]
        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

    pg = _PARAM_GENERATOR
    if pg is None:
        logger.info("Param-generator not available; using fallback heuristic.")
        return fallback()

    prompt = (
        "Recommend summarization settings for this text. Answer ONLY with JSON of the form:\n"
        '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
        "Text:\n'''"
        + text[:3000] + "'''"
    )

    try:
        out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1, early_stopping=True)[0]
        # different pipeline versions may return different keys; check both:
        out = out_item.get("generated_text") or out_item.get("summary_text") or out_item.get("text") or ""
        out = (out or "").strip()

        # COMMON FAILURE MODE: the model just echoes the input — reject that
        # If output contains a long substring of the input, treat as invalid.
        if not out:
            raise ValueError("Empty param-generator output")
        # If the returned text contains more than 40% of the original input words, treat it as echo
        input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
        out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
        if len(input_words) > 0 and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
            logger.warning("Param-generator appears to echo input; discarding and using heuristic.")
            return fallback()

        # Find JSON object in output
        cfg = None
        jmatch = re.search(r"\{.*\}", out, re.DOTALL)
        if jmatch:
            raw = jmatch.group()
            # Try the raw match first; only swap quotes if that fails, so valid
            # JSON containing apostrophes in values is not corrupted.  Parse
            # failures are an *expected* model failure mode, so they are handled
            # here (warning + fallback below) instead of tripping the outer
            # except and logging a full traceback.
            for candidate in (raw, raw.replace("'", '"')):
                try:
                    cfg = json.loads(candidate)
                    break
                except (json.JSONDecodeError, ValueError):
                    continue

        if not cfg or not isinstance(cfg, dict):
            logger.warning("Param-generator output not parseable as JSON: %s", out[:300])
            return fallback()

        # Coerce to str before .lower(): the model may emit null/numbers here,
        # and a bare .lower() would raise AttributeError into the outer except.
        length = str(cfg.get("length", "medium")).lower()
        tone = str(cfg.get("tone", "neutral")).lower()
        min_w = cfg.get("min_words") or cfg.get("min_length") or cfg.get("min")
        max_w = cfg.get("max_words") or cfg.get("max_length") or cfg.get("max")

        if length not in ("short", "medium", "long"):
            words = len(text.split())
            length = "short" if words < 150 else ("medium" if words < 800 else "long")
        if tone not in ("neutral", "formal", "casual", "bullet"):
            tone = "neutral"

        defaults_min, defaults_max = defaults.get(length, (50, 130))
        try:
            mn = int(min_w) if min_w is not None else defaults_min
            mx = int(max_w) if max_w is not None else defaults_max
        except Exception:
            mn, mx = defaults_min, defaults_max

        # Clamp to sane bounds and keep max strictly above min.
        mn = max(5, min(mn, 2000))
        mx = max(mn + 5, min(mx, 4000))
        logger.info("Param-generator suggested length=%s tone=%s min=%s max=%s", length, tone, mn, mx)
        return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}

    except Exception as e:
        logger.exception("Param-generator failed to produce usable config: %s", e)
        return fallback()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
# -------------------------
|
| 290 |
# Orchestrator: chunk summarization with threadpool + timeouts
|