Update app.py

app.py CHANGED
@@ -1,8 +1,8 @@
 # app.py
-# From Talk to Task —
+# From Talk to Task — Windowed extraction + two latency measures
 # Model: swiss-ai/Apertus-8B-Instruct-2509
-# Few-shot
-#
+# Few-shot: 1 each EN/FR/DE/IT; deterministic by default; optional sampling fallback toggle.
+# Soft token cap: 1024 by default. CUDA fp16 + optional 4-bit. GT scoring + downloads.
 
 import os
 import re
@@ -46,28 +46,20 @@ CONTEXT_GUIDE = (
     "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)\n"
 )
 
-# Few-shot: exactly
+# Few-shot: exactly one per language (compact)
 FEW_SHOTS = [
     # EN
-    {
-
-     "labels": ["schedule_meeting"]
-    },
+    {"transcript": "Agent: Can we meet Friday 3pm on Teams?\nClient: Yes, Friday 3pm works.\nAgent: I’ll send the invite.",
+     "labels": ["schedule_meeting"]},
     # FR
-    {
-
-     "labels": ["update_contact_info_non_postal"]
-    },
+    {"transcript": "Client: Mon numéro a changé: +41 44 000 00 00.\nConseiller: Merci, je mets à jour vos coordonnées.",
+     "labels": ["update_contact_info_non_postal"]},
    # DE
-    {
-
-     "labels": ["update_contact_info_postal_address"]
-    },
+    {"transcript": "Kunde: Neue Postadresse: Musterstrasse 1, 8000 Zürich.\nBerater: Danke, ich aktualisiere die Postadresse.",
+     "labels": ["update_contact_info_postal_address"]},
     # IT
-    {
-
-     "labels": ["update_kyc_total_assets"]
-    },
+    {"transcript": "Cliente: Totale patrimonio confermato a 8 milioni CHF.\nConsulente: Aggiorno i dati KYC sul totale degli asset.",
+     "labels": ["update_kyc_total_assets"]},
 ]
 
 # --------------------- WRITABLE HF CACHE -----------------------------
@@ -101,6 +93,27 @@ RE_DISCLAIMER = re.compile(r"^\s*disclaimer\s*:", re.IGNORECASE)
 RE_DROP = re.compile(r"(readme|terms|synthetic transcript)", re.IGNORECASE)
 SMALLTALK_RE = re.compile(r"\b(thanks?|merci|grazie|danke|tsch(ü|u)ss|ciao|bye|ok(ay)?)\b", re.IGNORECASE)
 
+# keyword windows (EN/FR/DE/IT) — expand as needed
+WINDOW_KEYWORDS = [
+    # meeting / schedule
+    r"\b(meet|meeting|schedule|appointment|teams|zoom|google meet|calendar)\b",
+    r"\b(rendez[- ]?vous|réunion|planifier|calendrier|teams|zoom)\b",
+    r"\b(termin|treffen|besprechung|kalender|teams|zoom)\b",
+    r"\b(appuntamento|riunione|calendario|teams|zoom)\b",
+    # address / phone / email
+    r"\b(address|street|avenue|road|postcode|phone|email)\b",
+    r"\b(adresse|rue|avenue|code postal|téléphone|courriel|email)\b",
+    r"\b(adresse|straße|strasse|plz|telefon|e-?mail)\b",
+    r"\b(indirizzo|via|cap|telefono|e-?mail)\b",
+    # KYC assets / totals / origin / purpose
+    r"\b(total assets|net worth|portfolio|real estate|origin of assets|source of wealth|purpose of relationship)\b",
+    r"\b(actifs totaux|patrimoine|immobilier|origine des fonds|source de richesse|but de la relation)\b",
+    r"\b(gesamtverm(ö|o)gen|verm(ö|o)gen|immobilien|herkunft der verm(ö|o)genswerte|zweck der gesch(ä|a)ftsbeziehung)\b",
+    r"\b(patrimonio totale|immobiliare|origine dei fondi|scopo della relazione)\b",
+    r"\b(chf|eur|usd|cur[13]|francs?)\b",
+    r"\b(\d{1,3}([.'’ ]\d{3})*(,\d+)?)(\s?(chf|eur|usd))\b",
+]
+
 def _json_from_text(text: str) -> str:
     s = text.strip()
     if s.startswith("{") and s.endswith("}"):
@@ -114,8 +127,7 @@ def safe_json_labels(s: str, allowed: List[str]) -> List[str]:
     except Exception:
         return []
     labels = data.get("labels", [])
-    clean = []
-    seen = set()
+    clean, seen = [], set()
     for lab in labels:
         if lab in allowed and lab not in seen:
             clean.append(lab); seen.add(lab)
@@ -221,6 +233,44 @@ def card_markdown(title: str, value: str, hint: str = "") -> str:
 </div>
 """
 
+# ------------------- WINDOWED EXTRACTION (fix for empty labels) -------------------
+
+def extract_windows(text: str, max_windows: int = 6, half_span_lines: int = 3) -> str:
+    """
+    Find up to `max_windows` windows around keyword hits; each window is ±`half_span_lines` lines.
+    If no hits, return the FIRST 8k characters instead of last chunk (common cause of misses).
+    """
+    lines = text.splitlines()
+    n = len(lines)
+    # collect hit line indices
+    hits: List[int] = []
+    pattern = re.compile("|".join(WINDOW_KEYWORDS), re.IGNORECASE)
+    for i, ln in enumerate(lines):
+        if pattern.search(ln):
+            hits.append(i)
+    # de-duplicate and cap
+    unique_hits = []
+    seen = set()
+    for idx in hits:
+        # bucket nearby hits to avoid redundant windows
+        bucket = idx // 2  # coarse bucketing
+        if bucket not in seen:
+            seen.add(bucket)
+            unique_hits.append(idx)
+    unique_hits = unique_hits[:max_windows]
+
+    if not unique_hits:
+        # return the opening chunk; most KYC/context often appears early
+        return "\n".join(lines[: min(2000, n)])
+
+    # Build windows and merge
+    windows = []
+    for idx in unique_hits:
+        a = max(0, idx - half_span_lines)
+        b = min(n, idx + half_span_lines + 1)
+        windows.append("\n".join(lines[a:b]))
+    return "\n...\n".join(windows)
+
 # -------------------------- MODEL -----------------------------------
 
 class HFModel:
@@ -264,10 +314,10 @@ class HFModel:
         self.model = self.model.to(DEVICE)
 
     @torch.inference_mode()
-    def generate_json(self, prompt: str, max_new_tokens=48, allow_sampling=False) -> Tuple[str, Dict[str, int]]:
+    def generate_json(self, prompt: str, max_new_tokens=48, allow_sampling=False) -> Tuple[str, Dict[str, int], float]:
         """
         Deterministic by default. If allow_sampling=True (toggle), we use mild temperature.
-        Returns (json_text, token_stats)
+        Returns (json_text, token_stats, model_latency_seconds)
         """
         tok = self.tokenizer
         mdl = self.model
@@ -282,12 +332,13 @@ class HFModel:
             eos_token_id=tok.eos_token_id,
         )
         if allow_sampling:
-            # mild sampling; disabled by default to avoid CUDA multinomial issues on T4
             kwargs.update(dict(do_sample=True, temperature=0.25, top_p=0.9))
         else:
             kwargs.update(dict(do_sample=False, temperature=0.0, top_p=1.0))
 
+        t0 = time.perf_counter()
         out = mdl.generate(**inputs, **kwargs)
+        model_latency = time.perf_counter() - t0
 
         prompt_tokens = int(inputs.input_ids.shape[-1])
         output_tokens = int(out.shape[-1] - inputs.input_ids.shape[-1])
@@ -299,7 +350,7 @@ class HFModel:
             "prompt_tokens": prompt_tokens,
             "output_tokens": output_tokens,
             "total_tokens": total_tokens,
-        }
+        }, model_latency
 
 _MODEL_CACHE: Dict[Tuple[str, Optional[str], bool], HFModel] = {}
 
@@ -323,6 +374,19 @@ def preprocess_text(txt: str, add_header: bool, strip_smalltalk: bool) -> str:
     cleaned = "\n".join(lines[-32768:])
     return f"[EMAIL/MESSAGE SIGNAL]\n{cleaned}" if add_header else cleaned
 
+def window_then_cap(text: str, soft_token_cap: int) -> Tuple[str, str]:
+    """
+    Apply keyword windowing; then hard cap by approximate chars (~4 chars/token).
+    Returns (final_text, info_string).
+    """
+    windowed = extract_windows(text)
+    approx_chars = int(max(soft_token_cap, 0) * 4) if soft_token_cap else 0
+    info = "windowed"
+    if approx_chars and len(windowed) > approx_chars:
+        windowed = windowed[:approx_chars]
+        info = f"windowed + soft cap ~{soft_token_cap}t"
+    return windowed, info
+
 def run_single(
     custom_repo_id: str,
     rules_json: Optional[gr.File],
@@ -341,45 +405,41 @@ def run_single(
 ):
     """Returns: repo, revision, predicted_json, metric_cards_md, diag_cards_md, raw_metrics_json"""
 
+    total_t0 = time.perf_counter()  # TOTAL latency starts here
+
     repo = (custom_repo_id or DEFAULT_REPO).strip()
     revision = "main"
     allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET
 
-    # Preprocess + cap
-
+    # Preprocess + window + cap
+    effective_len_before = len(transcript)
     if preprocess:
         transcript = preprocess_text(transcript, add_header, strip_smalltalk)
-
-
-    cap_info = ""
-    if soft_token_cap and soft_token_cap > 0:
-        approx_chars = int(soft_token_cap * 4)
-        if len(transcript) > approx_chars:
-            transcript = transcript[-approx_chars:]
-            cap_info = f"(soft cap ~{soft_token_cap}t)"
+    windowed, cap_info = window_then_cap(transcript, soft_token_cap)
+    effective_len_after = len(windowed)
 
     # Build prompt
     system = system_instructions or SYSTEM_INSTRUCTIONS_BASE
-    prompt = build_prompt(system, context_text or CONTEXT_GUIDE,
+    prompt = build_prompt(system, context_text or CONTEXT_GUIDE, windowed, allowed, use_fewshot)
 
     model = get_model(repo, revision, load_in_4bit)
 
-    # Deterministic pass only
-
-    raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
+    # Deterministic pass only
+    raw_json, tok_stats, model_latency = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
     pred_labels = safe_json_labels(raw_json, allowed)
 
     # Optional fallback sampling (OFF by default)
    fallback_used = False
     if enable_fallback_sampling and not pred_labels:
-        raw_json2, tok_stats2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
+        raw_json2, tok_stats2, model_latency2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
         pred_labels2 = safe_json_labels(raw_json2, allowed)
         if pred_labels2:
            pred_labels = pred_labels2
            tok_stats = tok_stats2
+           model_latency = model_latency2
            fallback_used = True
 
-    total_latency = time.perf_counter() -
+    total_latency = time.perf_counter() - total_t0
     est_cost = (total_latency / 3600.0) * max(0.0, float(hourly_rate or 0.0))
 
     # Ground truth
@@ -406,16 +466,17 @@ def run_single(
     metric_cards += card("Missing labels", json.dumps(missing, ensure_ascii=False) if gt_labels is not None else "—", "Expected but not predicted")
     metric_cards += card("Extra labels", json.dumps(extra, ensure_ascii=False) if gt_labels is not None else "—", "Predicted but not expected")
 
-    # Diagnostics cards
+    # Diagnostics cards — now with TWO latency measures
     diag_cards = ""
     diag_cards += card("Model / Rev", f"{repo} / {revision}")
     diag_cards += card("Device", f"{DEVICE} ({GPU_NAME})")
     diag_cards += card("Precision dtype", f"{DTYPE_FALLBACK}")
     diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
     diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))
-    diag_cards += card("Effective text length", f"{
+    diag_cards += card("Effective text length", f"before={effective_len_before} chars → after={effective_len_after} ({cap_info})")
     diag_cards += card("Tokens", f"prompt={tok_stats['prompt_tokens']}, output={tok_stats['output_tokens']}, total={tok_stats['total_tokens']}", "Token counts influence latency & cost")
-    diag_cards += card("
+    diag_cards += card("Model latency", f"{model_latency:.2f} s", "Time spent in model.generate(...) only")
+    diag_cards += card("Total latency", f"{total_latency:.2f} s", "End-to-end time (preprocess → model → postprocess)")
     diag_cards += card("Cost (est.)", f"${(est_cost):.6f} @ {hourly_rate:.4f}/hr")
     diag_cards += card("Fallback sampling used", "Yes" if fallback_used else "No", "Sampling can be slower/unstable on T4; off by default")
 
@@ -431,9 +492,11 @@ def run_single(
         "extra": extra if gt_labels is not None else None,
         "per_label": per_label if gt_labels is not None else None,
         "token_stats": tok_stats,
-        "
+        "model_latency_seconds": round(model_latency, 3),
+        "total_latency_seconds": round(total_latency, 3),
         "estimated_cost_usd": round(est_cost, 6),
         "fallback_used": fallback_used,
+        "cap_info": cap_info,
     }
 
     return (
@@ -475,10 +538,11 @@ def run_batch(
     model = get_model(repo, revision, load_in_4bit)
 
     rows = [["filename","labels"]]
-    per_sample_rows = [["filename","pred_labels","gold_labels","precision","recall","f1","exact_match","hamming_loss","missing","extra"]]
+    per_sample_rows = [["filename","pred_labels","gold_labels","precision","recall","f1","exact_match","hamming_loss","missing","extra","model_latency_s","total_latency_s","prompt_tokens","output_tokens"]]
     totals = {"tp":0,"fp":0,"fn":0,"pred_total":0,"gold_total":0}
     label_global = {lab: {"tp":0,"fp":0,"fn":0} for lab in allowed}
-    total_prompt_tokens = 0; total_output_tokens = 0;
+    total_prompt_tokens = 0; total_output_tokens = 0; sum_model_s = 0.0; sum_total_s = 0.0
+    n=0; with_gt=0
 
     system = system_instructions or SYSTEM_INSTRUCTIONS_BASE
 
@@ -488,29 +552,28 @@ def run_batch(
         except Exception:
             rows.append([name, "[] # unreadable"]); continue
 
+        total_t0 = time.perf_counter()  # TOTAL latency per file
+
         if preprocess:
             txt = preprocess_text(txt, add_header, strip_smalltalk)
+        txt_windowed, cap_info = window_then_cap(txt, soft_token_cap)
 
-
-        approx_chars = int(soft_token_cap * 4)
-        if len(txt) > approx_chars:
-            txt = txt[-approx_chars:]
-
-        prompt = build_prompt(system, context_text or CONTEXT_GUIDE, txt, allowed, use_fewshot)
+        prompt = build_prompt(system, context_text or CONTEXT_GUIDE, txt_windowed, allowed, use_fewshot)
 
-
-        raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
+        raw_json, tok_stats, model_latency = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
         pred = safe_json_labels(raw_json, allowed)
         if enable_fallback_sampling and not pred:
-            raw_json2, tok_stats2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
+            raw_json2, tok_stats2, model_latency2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
             pred2 = safe_json_labels(raw_json2, allowed)
             if pred2:
-                pred = pred2
-
+                pred = pred2; tok_stats = tok_stats2; model_latency = model_latency2
+
+        total_latency = time.perf_counter() - total_t0
 
-        total_secs += (time.perf_counter() - t0)
         total_prompt_tokens += tok_stats["prompt_tokens"]
         total_output_tokens += tok_stats["output_tokens"]
+        sum_model_s += model_latency
+        sum_total_s += total_latency
         n += 1
 
         rows.append([name, json.dumps(pred, ensure_ascii=False)])
@@ -538,13 +601,21 @@ def run_batch(
                 round(ham,4),
                 json.dumps(missing, ensure_ascii=False),
                 json.dumps(extra, ensure_ascii=False),
+                round(model_latency,3), round(total_latency,3),
+                tok_stats["prompt_tokens"], tok_stats["output_tokens"],
+            ])
+        else:
+            per_sample_rows.append([
+                name, json.dumps(pred, ensure_ascii=False), None, None, None, None, None, None, None, None,
+                round(model_latency,3), round(total_latency,3),
+                tok_stats["prompt_tokens"], tok_stats["output_tokens"],
             ])
 
     tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
     prec = tp / (tp + fp) if (tp + fp) else 0.0
     rec = tp / (tp + fn) if (tp + fn) else 0.0
     f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
-    est_cost = (
+    est_cost = (sum_total_s / 3600.0) * max(0.0, float(hourly_rate or 0.0))
 
     coverage = {lab: 0 for lab in allowed}
     for r in rows[1:]:
@@ -572,8 +643,10 @@ def run_batch(
             "avg_prompt_tokens": round(total_prompt_tokens / n, 2) if n else 0.0,
             "avg_output_tokens": round(total_output_tokens / n, 2) if n else 0.0,
         },
-        "
-        "
+        "latency_seconds_model_total": round(sum_model_s, 3),
+        "latency_seconds_total": round(sum_total_s, 3),
+        "avg_model_latency_seconds": round(sum_model_s / n, 3) if n else 0.0,
+        "avg_total_latency_seconds": round(sum_total_s / n, 3) if n else 0.0,
        "estimated_cost_usd": round(est_cost, 6),
     }
 
@@ -587,7 +660,8 @@ def run_batch(
     diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
     diag_cards += card("Files processed", f"{n} (with GT: {with_gt})")
     diag_cards += card("Tokens (totals)", f"prompt={total_prompt_tokens}, output={total_output_tokens}")
-    diag_cards += card("Latency", f"total={summary['
+    diag_cards += card("Latency (model)", f"total={summary['latency_seconds_model_total']} s, avg={summary['avg_model_latency_seconds']} s")
+    diag_cards += card("Latency (total)", f"total={summary['latency_seconds_total']} s, avg={summary['avg_total_latency_seconds']} s")
     diag_cards += card("Cost (est.)", f"${summary['estimated_cost_usd']} @ {hourly_rate:.4f}/hr")
     diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))
 
@@ -612,13 +686,17 @@ def run_batch(
 
 # ----------------------------- UI -----------------------------------
 
-with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics (stable)") as demo:
+with gr.Blocks(title="From Talk to Task — Windowed + Two Latencies") as demo:
    gr.Markdown(
        f"""
 # From Talk to Task — Accuracy & Diagnostics (EN/FR/DE/IT)
 
 **Default model:** `{DEFAULT_REPO}` (GPU + 4-bit recommended).
-
+Now includes **keyword windowing** (keeps early cues) and **two latency measures**:
+- **Model latency:** time spent inside the model generate call
+- **Total latency:** end-to-end time (preprocess → model → postprocess)
+
+Upload ground truth to compute **Precision / Recall / F1 / Exact match / Hamming loss**.
 Upload a **Rules JSON** (`{{"labels":[...]}}`) to override allowed labels.
 
 **Model output schema:** `{{"labels": [...]}}`
@@ -640,7 +718,7 @@ with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics (stable)") as
             context = gr.Textbox(label="Context (User prefix)", value=CONTEXT_GUIDE, lines=6)
 
         with gr.Row():
-            soft_cap = gr.Slider(512, 32768, value=1024, step=1, label="Soft token cap (approx)")
+            soft_cap = gr.Slider(512, 32768, value=1024, step=1, label="Soft token cap (approx; applied after keyword windows)")
             preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
             add_header = gr.Checkbox(value=True, label="Add cues header")
             strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
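For readers who want to try the windowing-then-cap idea outside the Space, here is a minimal, self-contained sketch of the same technique. It is not the app's code: the keyword list is deliberately tiny, the names find_keyword_windows and window_then_cap only mirror the helpers added in this commit, and the "~4 characters per token" cap is the same rough heuristic the diff uses.

# Minimal sketch of keyword windowing + soft character cap.
# Assumptions: ~4 chars per token; illustrative keyword list and function names.
import re
from typing import Tuple

KEYWORDS = [
    r"\b(meet|meeting|schedule|appointment)\b",   # scheduling cues
    r"\b(address|phone|email)\b",                 # contact-info cues
    r"\b(total assets|chf|eur|usd)\b",            # KYC / asset cues
]
PATTERN = re.compile("|".join(KEYWORDS), re.IGNORECASE)

def find_keyword_windows(text: str, max_windows: int = 6, half_span: int = 3) -> str:
    """Keep up to `max_windows` blocks of ±`half_span` lines around keyword hits."""
    lines = text.splitlines()
    hits = [i for i, ln in enumerate(lines) if PATTERN.search(ln)]
    # coarse de-duplication: at most one window per pair of adjacent lines
    kept, seen = [], set()
    for i in hits:
        if i // 2 not in seen:
            seen.add(i // 2)
            kept.append(i)
    kept = kept[:max_windows]
    if not kept:  # no hits: fall back to the opening chunk rather than the tail
        return "\n".join(lines[:2000])
    windows = ["\n".join(lines[max(0, i - half_span): i + half_span + 1]) for i in kept]
    return "\n...\n".join(windows)

def window_then_cap(text: str, soft_token_cap: int) -> Tuple[str, str]:
    """Window first, then cut to roughly 4 chars per allowed token; report what happened."""
    windowed = find_keyword_windows(text)
    approx_chars = max(soft_token_cap, 0) * 4
    if approx_chars and len(windowed) > approx_chars:
        return windowed[:approx_chars], f"windowed + soft cap ~{soft_token_cap}t"
    return windowed, "windowed"

if __name__ == "__main__":
    transcript = (
        "Agent: Good morning, how was your trip?\n"
        "Client: Fine, thanks.\n"
        "Client: Can we meet Friday at 3pm?\n"
        "Agent: Sure, I'll send an invite.\n"
        "Client: Also, my phone number changed.\n"
    )
    text, info = window_then_cap(transcript, soft_token_cap=1024)
    print(info)
    print(text)

The design intent sketched here matches the commit's stated motivation: keep the lines most likely to carry task cues (meetings, contact changes, asset figures) instead of blindly truncating from the end of a long transcript, which is a common reason the model returned empty label lists.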