RishiRP committed (verified)
Commit db91bf5 · Parent: 3e2cf36

Update app.py

Files changed (1):
  1. app.py (+145 -67)

app.py CHANGED
@@ -1,8 +1,8 @@
  # app.py
- # From Talk to Task — Accuracy & Diagnostics (stable / fast)
  # Model: swiss-ai/Apertus-8B-Instruct-2509
- # Few-shot (EN/FR/DE/IT one each), deterministic by default, optional fallback sampling toggle,
- # soft token cap = 1024 by default, CUDA fp16 + optional 4-bit, GT scoring & downloads.

  import os
  import re
@@ -46,28 +46,20 @@ CONTEXT_GUIDE = (
  "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)\n"
  )

- # Few-shot: exactly ONE compact example per language
  FEW_SHOTS = [
  # EN
- {
- "transcript": "Agent: Can we meet Friday 3pm on Teams?\nClient: Yes, Friday 3pm works.\nAgent: I’ll send the invite.",
- "labels": ["schedule_meeting"]
- },
  # FR
- {
- "transcript": "Client: Mon numéro a changé: +41 44 000 00 00.\nConseiller: Merci, je mets à jour vos coordonnées.",
- "labels": ["update_contact_info_non_postal"]
- },
  # DE
- {
- "transcript": "Kunde: Neue Postadresse: Musterstrasse 1, 8000 Zürich.\nBerater: Danke, ich aktualisiere die Postadresse.",
- "labels": ["update_contact_info_postal_address"]
- },
  # IT
- {
- "transcript": "Cliente: Totale patrimonio confermato a 8 milioni CHF.\nConsulente: Perfetto, aggiorno i dati KYC sul totale degli asset.",
- "labels": ["update_kyc_total_assets"]
- },
  ]

  # --------------------- WRITABLE HF CACHE -----------------------------
@@ -101,6 +93,27 @@ RE_DISCLAIMER = re.compile(r"^\s*disclaimer\s*:", re.IGNORECASE)
  RE_DROP = re.compile(r"(readme|terms|synthetic transcript)", re.IGNORECASE)
  SMALLTALK_RE = re.compile(r"\b(thanks?|merci|grazie|danke|tsch(ü|u)ss|ciao|bye|ok(ay)?)\b", re.IGNORECASE)

  def _json_from_text(text: str) -> str:
  s = text.strip()
  if s.startswith("{") and s.endswith("}"):
@@ -114,8 +127,7 @@ def safe_json_labels(s: str, allowed: List[str]) -> List[str]:
  except Exception:
  return []
  labels = data.get("labels", [])
- clean = []
- seen = set()
  for lab in labels:
  if lab in allowed and lab not in seen:
  clean.append(lab); seen.add(lab)
@@ -221,6 +233,44 @@ def card_markdown(title: str, value: str, hint: str = "") -> str:
  </div>
  """

  # -------------------------- MODEL -----------------------------------

  class HFModel:
@@ -264,10 +314,10 @@ class HFModel:
  self.model = self.model.to(DEVICE)

  @torch.inference_mode()
- def generate_json(self, prompt: str, max_new_tokens=48, allow_sampling=False) -> Tuple[str, Dict[str, int]]:
  """
  Deterministic by default. If allow_sampling=True (toggle), we use mild temperature.
- Returns (json_text, token_stats)
  """
  tok = self.tokenizer
  mdl = self.model
@@ -282,12 +332,13 @@ class HFModel:
  eos_token_id=tok.eos_token_id,
  )
  if allow_sampling:
- # mild sampling; disabled by default to avoid CUDA multinomial issues on T4
  kwargs.update(dict(do_sample=True, temperature=0.25, top_p=0.9))
  else:
  kwargs.update(dict(do_sample=False, temperature=0.0, top_p=1.0))

  out = mdl.generate(**inputs, **kwargs)

  prompt_tokens = int(inputs.input_ids.shape[-1])
  output_tokens = int(out.shape[-1] - inputs.input_ids.shape[-1])
@@ -299,7 +350,7 @@ class HFModel:
  "prompt_tokens": prompt_tokens,
  "output_tokens": output_tokens,
  "total_tokens": total_tokens,
- }

  _MODEL_CACHE: Dict[Tuple[str, Optional[str], bool], HFModel] = {}

@@ -323,6 +374,19 @@ def preprocess_text(txt: str, add_header: bool, strip_smalltalk: bool) -> str:
  cleaned = "\n".join(lines[-32768:])
  return f"[EMAIL/MESSAGE SIGNAL]\n{cleaned}" if add_header else cleaned

  def run_single(
  custom_repo_id: str,
  rules_json: Optional[gr.File],
@@ -341,45 +405,41 @@
  ):
  """Returns: repo, revision, predicted_json, metric_cards_md, diag_cards_md, raw_metrics_json"""

  repo = (custom_repo_id or DEFAULT_REPO).strip()
  revision = "main"
  allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET

- # Preprocess + cap
- effective_len = len(transcript)
  if preprocess:
  transcript = preprocess_text(transcript, add_header, strip_smalltalk)
- effective_len = len(transcript)
-
- cap_info = ""
- if soft_token_cap and soft_token_cap > 0:
- approx_chars = int(soft_token_cap * 4)
- if len(transcript) > approx_chars:
- transcript = transcript[-approx_chars:]
- cap_info = f"(soft cap ~{soft_token_cap}t)"

  # Build prompt
  system = system_instructions or SYSTEM_INSTRUCTIONS_BASE
- prompt = build_prompt(system, context_text or CONTEXT_GUIDE, transcript, allowed, use_fewshot)

  model = get_model(repo, revision, load_in_4bit)

- # Deterministic pass only (fast & stable)
- t0 = time.perf_counter()
- raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
  pred_labels = safe_json_labels(raw_json, allowed)

  # Optional fallback sampling (OFF by default)
  fallback_used = False
  if enable_fallback_sampling and not pred_labels:
- raw_json2, tok_stats2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
  pred_labels2 = safe_json_labels(raw_json2, allowed)
  if pred_labels2:
  pred_labels = pred_labels2
  tok_stats = tok_stats2
  fallback_used = True

- total_latency = time.perf_counter() - t0
  est_cost = (total_latency / 3600.0) * max(0.0, float(hourly_rate or 0.0))

  # Ground truth
@@ -406,16 +466,17 @@
  metric_cards += card("Missing labels", json.dumps(missing, ensure_ascii=False) if gt_labels is not None else "—", "Expected but not predicted")
  metric_cards += card("Extra labels", json.dumps(extra, ensure_ascii=False) if gt_labels is not None else "—", "Predicted but not expected")

- # Diagnostics cards
  diag_cards = ""
  diag_cards += card("Model / Rev", f"{repo} / {revision}")
  diag_cards += card("Device", f"{DEVICE} ({GPU_NAME})")
  diag_cards += card("Precision dtype", f"{DTYPE_FALLBACK}")
  diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
  diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))
- diag_cards += card("Effective text length", f"{effective_len} chars {cap_info}")
  diag_cards += card("Tokens", f"prompt={tok_stats['prompt_tokens']}, output={tok_stats['output_tokens']}, total={tok_stats['total_tokens']}", "Token counts influence latency & cost")
- diag_cards += card("Latency", f"{total_latency:.2f} s", "End-to-end time")
  diag_cards += card("Cost (est.)", f"${(est_cost):.6f} @ {hourly_rate:.4f}/hr")
  diag_cards += card("Fallback sampling used", "Yes" if fallback_used else "No", "Sampling can be slower/unstable on T4; off by default")

@@ -431,9 +492,11 @@
  "extra": extra if gt_labels is not None else None,
  "per_label": per_label if gt_labels is not None else None,
  "token_stats": tok_stats,
- "latency_seconds": round(total_latency, 3),
  "estimated_cost_usd": round(est_cost, 6),
  "fallback_used": fallback_used,
  }

  return (
@@ -475,10 +538,11 @@
  model = get_model(repo, revision, load_in_4bit)

  rows = [["filename","labels"]]
- per_sample_rows = [["filename","pred_labels","gold_labels","precision","recall","f1","exact_match","hamming_loss","missing","extra"]]
  totals = {"tp":0,"fp":0,"fn":0,"pred_total":0,"gold_total":0}
  label_global = {lab: {"tp":0,"fp":0,"fn":0} for lab in allowed}
- total_prompt_tokens = 0; total_output_tokens = 0; total_secs = 0.0; n=0; with_gt=0

  system = system_instructions or SYSTEM_INSTRUCTIONS_BASE

@@ -488,29 +552,28 @@
  except Exception:
  rows.append([name, "[] # unreadable"]); continue

  if preprocess:
  txt = preprocess_text(txt, add_header, strip_smalltalk)

- if soft_token_cap and soft_token_cap > 0:
- approx_chars = int(soft_token_cap * 4)
- if len(txt) > approx_chars:
- txt = txt[-approx_chars:]
-
- prompt = build_prompt(system, context_text or CONTEXT_GUIDE, txt, allowed, use_fewshot)

- t0 = time.perf_counter()
- raw_json, tok_stats = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
  pred = safe_json_labels(raw_json, allowed)
  if enable_fallback_sampling and not pred:
- raw_json2, tok_stats2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
  pred2 = safe_json_labels(raw_json2, allowed)
  if pred2:
- pred = pred2
- tok_stats = tok_stats2

- total_secs += (time.perf_counter() - t0)
  total_prompt_tokens += tok_stats["prompt_tokens"]
  total_output_tokens += tok_stats["output_tokens"]
  n += 1

  rows.append([name, json.dumps(pred, ensure_ascii=False)])
@@ -538,13 +601,21 @@
  round(ham,4),
  json.dumps(missing, ensure_ascii=False),
  json.dumps(extra, ensure_ascii=False),
  ])

  tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
  prec = tp / (tp + fp) if (tp + fp) else 0.0
  rec = tp / (tp + fn) if (tp + fn) else 0.0
  f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
- est_cost = (total_secs / 3600.0) * max(0.0, float(hourly_rate or 0.0))

  coverage = {lab: 0 for lab in allowed}
  for r in rows[1:]:
@@ -572,8 +643,10 @@
  "avg_prompt_tokens": round(total_prompt_tokens / n, 2) if n else 0.0,
  "avg_output_tokens": round(total_output_tokens / n, 2) if n else 0.0,
  },
- "latency_seconds_total": round(total_secs, 3),
- "avg_latency_seconds": round(total_secs / n, 3) if n else 0.0,
  "estimated_cost_usd": round(est_cost, 6),
  }

@@ -587,7 +660,8 @@
  diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
  diag_cards += card("Files processed", f"{n} (with GT: {with_gt})")
  diag_cards += card("Tokens (totals)", f"prompt={total_prompt_tokens}, output={total_output_tokens}")
- diag_cards += card("Latency", f"total={summary['latency_seconds_total']} s, avg={summary['avg_latency_seconds']} s")
  diag_cards += card("Cost (est.)", f"${summary['estimated_cost_usd']} @ {hourly_rate:.4f}/hr")
  diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))

@@ -612,13 +686,17 @@

  # ----------------------------- UI -----------------------------------

- with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics (stable)") as demo:
  gr.Markdown(
  f"""
  # From Talk to Task — Accuracy & Diagnostics (EN/FR/DE/IT)

  **Default model:** `{DEFAULT_REPO}` (GPU + 4-bit recommended).
- Upload ground truth to compute **Precision / Recall / F1 / Exact match / Hamming loss**.
  Upload a **Rules JSON** (`{{"labels":[...]}}`) to override allowed labels.

  **Model output schema:** `{{"labels": [...]}}`
@@ -640,7 +718,7 @@ with gr.Blocks(title="From Talk to Task — Accuracy & Diagnostics (stable)") as
  context = gr.Textbox(label="Context (User prefix)", value=CONTEXT_GUIDE, lines=6)

  with gr.Row():
- soft_cap = gr.Slider(512, 32768, value=1024, step=1, label="Soft token cap (approx)")
  preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
  add_header = gr.Checkbox(value=True, label="Add cues header")
  strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")

  # app.py
+ # From Talk to Task — Windowed extraction + two latency measures
  # Model: swiss-ai/Apertus-8B-Instruct-2509
+ # Few-shot: 1 each EN/FR/DE/IT; deterministic by default; optional sampling fallback toggle.
+ # Soft token cap: 1024 by default. CUDA fp16 + optional 4-bit. GT scoring + downloads.

  import os
  import re

  "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)\n"
  )

+ # Few-shot: exactly one per language (compact)
  FEW_SHOTS = [
  # EN
+ {"transcript": "Agent: Can we meet Friday 3pm on Teams?\nClient: Yes, Friday 3pm works.\nAgent: I’ll send the invite.",
+ "labels": ["schedule_meeting"]},
  # FR
+ {"transcript": "Client: Mon numéro a changé: +41 44 000 00 00.\nConseiller: Merci, je mets à jour vos coordonnées.",
+ "labels": ["update_contact_info_non_postal"]},
  # DE
+ {"transcript": "Kunde: Neue Postadresse: Musterstrasse 1, 8000 Zürich.\nBerater: Danke, ich aktualisiere die Postadresse.",
+ "labels": ["update_contact_info_postal_address"]},
  # IT
+ {"transcript": "Cliente: Totale patrimonio confermato a 8 milioni CHF.\nConsulente: Aggiorno i dati KYC sul totale degli asset.",
+ "labels": ["update_kyc_total_assets"]},
  ]

  # --------------------- WRITABLE HF CACHE -----------------------------
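For orientation, a minimal sketch of how the FEW_SHOTS entries above could be serialized into the prompt. build_prompt itself is not part of this diff, so the helper name and the "Transcript / Labels" layout below are illustrative assumptions, not the app's actual formatting.

```python
import json

def render_few_shots(few_shots):
    # Hypothetical helper (not in app.py): one "transcript -> JSON labels" pair per example.
    blocks = []
    for ex in few_shots:
        blocks.append(
            "Transcript:\n" + ex["transcript"] + "\n"
            "Labels: " + json.dumps({"labels": ex["labels"]}, ensure_ascii=False)
        )
    return "\n\n".join(blocks)

# e.g. the EN example would render as:
# Transcript:
# Agent: Can we meet Friday 3pm on Teams?
# ...
# Labels: {"labels": ["schedule_meeting"]}
```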
 
  RE_DROP = re.compile(r"(readme|terms|synthetic transcript)", re.IGNORECASE)
  SMALLTALK_RE = re.compile(r"\b(thanks?|merci|grazie|danke|tsch(ü|u)ss|ciao|bye|ok(ay)?)\b", re.IGNORECASE)

+ # keyword windows (EN/FR/DE/IT) — expand as needed
+ WINDOW_KEYWORDS = [
+ # meeting / schedule
+ r"\b(meet|meeting|schedule|appointment|teams|zoom|google meet|calendar)\b",
+ r"\b(rendez[- ]?vous|réunion|planifier|calendrier|teams|zoom)\b",
+ r"\b(termin|treffen|besprechung|kalender|teams|zoom)\b",
+ r"\b(appuntamento|riunione|calendario|teams|zoom)\b",
+ # address / phone / email
+ r"\b(address|street|avenue|road|postcode|phone|email)\b",
+ r"\b(adresse|rue|avenue|code postal|téléphone|courriel|email)\b",
+ r"\b(adresse|straße|strasse|plz|telefon|e-?mail)\b",
+ r"\b(indirizzo|via|cap|telefono|e-?mail)\b",
+ # KYC assets / totals / origin / purpose
+ r"\b(total assets|net worth|portfolio|real estate|origin of assets|source of wealth|purpose of relationship)\b",
+ r"\b(actifs totaux|patrimoine|immobilier|origine des fonds|source de richesse|but de la relation)\b",
+ r"\b(gesamtverm(ö|o)gen|verm(ö|o)gen|immobilien|herkunft der verm(ö|o)genswerte|zweck der gesch(ä|a)ftsbeziehung)\b",
+ r"\b(patrimonio totale|immobiliare|origine dei fondi|scopo della relazione)\b",
+ r"\b(chf|eur|usd|cur[13]|francs?)\b",
+ r"\b(\d{1,3}([.'’ ]\d{3})*(,\d+)?)(\s?(chf|eur|usd))\b",
+ ]
+
  def _json_from_text(text: str) -> str:
  s = text.strip()
  if s.startswith("{") and s.endswith("}"):

  except Exception:
  return []
  labels = data.get("labels", [])
+ clean, seen = [], set()
  for lab in labels:
  if lab in allowed and lab not in seen:
  clean.append(lab); seen.add(lab)

  </div>
  """

+ # ------------------- WINDOWED EXTRACTION (fix for empty labels) -------------------
+
+ def extract_windows(text: str, max_windows: int = 6, half_span_lines: int = 3) -> str:
+ """
+ Find up to `max_windows` windows around keyword hits; each window is ±`half_span_lines` lines.
+ If no hits, return the opening chunk (first lines) instead of the tail (a common cause of misses).
+ """
+ lines = text.splitlines()
+ n = len(lines)
+ # collect hit line indices
+ hits: List[int] = []
+ pattern = re.compile("|".join(WINDOW_KEYWORDS), re.IGNORECASE)
+ for i, ln in enumerate(lines):
+ if pattern.search(ln):
+ hits.append(i)
+ # de-duplicate and cap
+ unique_hits = []
+ seen = set()
+ for idx in hits:
+ # bucket nearby hits to avoid redundant windows
+ bucket = idx // 2  # coarse bucketing
+ if bucket not in seen:
+ seen.add(bucket)
+ unique_hits.append(idx)
+ unique_hits = unique_hits[:max_windows]
+
+ if not unique_hits:
+ # return the opening chunk; most KYC/context often appears early
+ return "\n".join(lines[: min(2000, n)])
+
+ # Build windows and merge
+ windows = []
+ for idx in unique_hits:
+ a = max(0, idx - half_span_lines)
+ b = min(n, idx + half_span_lines + 1)
+ windows.append("\n".join(lines[a:b]))
+ return "\n...\n".join(windows)
+
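A quick illustrative check of extract_windows on an invented transcript (run in the context of app.py, where WINDOW_KEYWORDS is defined): lines matching the keyword patterns become window centres, and each window keeps ±half_span_lines surrounding lines.

```python
sample = "\n".join([
    "Agent: Good morning, thanks for joining.",
    "Client: My phone number changed, it is +41 44 000 00 00.",
    "Agent: Noted, I will update it.",
    "Client: Can we also schedule a meeting on Teams next Friday?",
    "Agent: Sure, I will send a calendar invite.",
])
print(extract_windows(sample, max_windows=6, half_span_lines=1))
# "phone", "schedule"/"meeting"/"teams" and "calendar" are keyword hits, so three
# (here overlapping) windows are produced; disjoint windows are joined with "\n...\n".
```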
  # -------------------------- MODEL -----------------------------------

  class HFModel:

  self.model = self.model.to(DEVICE)

  @torch.inference_mode()
+ def generate_json(self, prompt: str, max_new_tokens=48, allow_sampling=False) -> Tuple[str, Dict[str, int], float]:
  """
  Deterministic by default. If allow_sampling=True (toggle), we use mild temperature.
+ Returns (json_text, token_stats, model_latency_seconds)
  """
  tok = self.tokenizer
  mdl = self.model

  eos_token_id=tok.eos_token_id,
  )
  if allow_sampling:
  kwargs.update(dict(do_sample=True, temperature=0.25, top_p=0.9))
  else:
  kwargs.update(dict(do_sample=False, temperature=0.0, top_p=1.0))

+ t0 = time.perf_counter()
  out = mdl.generate(**inputs, **kwargs)
+ model_latency = time.perf_counter() - t0

  prompt_tokens = int(inputs.input_ids.shape[-1])
  output_tokens = int(out.shape[-1] - inputs.input_ids.shape[-1])

  "prompt_tokens": prompt_tokens,
  "output_tokens": output_tokens,
  "total_tokens": total_tokens,
+ }, model_latency

  _MODEL_CACHE: Dict[Tuple[str, Optional[str], bool], HFModel] = {}

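A minimal caller-side sketch of the new three-value return, using build_prompt, get_model and the defaults defined elsewhere in app.py (the transcript string and the argument order, which mirrors the calls in run_single, are the only assumptions here):

```python
prompt = build_prompt(SYSTEM_INSTRUCTIONS_BASE, CONTEXT_GUIDE,
                      "Client: Neue Postadresse: Musterstrasse 1, 8000 Zürich.",
                      DEFAULT_LABEL_SET, True)
model = get_model(DEFAULT_REPO, "main", True)
raw_json, tok_stats, model_latency = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
print(safe_json_labels(raw_json, DEFAULT_LABEL_SET), tok_stats["total_tokens"], f"{model_latency:.2f}s")
```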
  cleaned = "\n".join(lines[-32768:])
  return f"[EMAIL/MESSAGE SIGNAL]\n{cleaned}" if add_header else cleaned

+ def window_then_cap(text: str, soft_token_cap: int) -> Tuple[str, str]:
+ """
+ Apply keyword windowing; then hard cap by approximate chars (~4 chars/token).
+ Returns (final_text, info_string).
+ """
+ windowed = extract_windows(text)
+ approx_chars = int(max(soft_token_cap, 0) * 4) if soft_token_cap else 0
+ info = "windowed"
+ if approx_chars and len(windowed) > approx_chars:
+ windowed = windowed[:approx_chars]
+ info = f"windowed + soft cap ~{soft_token_cap}t"
+ return windowed, info
+
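The soft cap is approximate: it budgets roughly 4 characters per token, so the default cap of 1024 tokens trims the windowed text to about 4096 characters. A small illustration (invented text, run in the context of app.py):

```python
text = "Client: total assets are 8 million CHF. " * 400   # ~16,400 chars, keyword-bearing
final_text, info = window_then_cap(text, soft_token_cap=1024)
assert len(final_text) <= 1024 * 4                        # character budget = 4096
print(info)                                               # "windowed + soft cap ~1024t"
```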
  def run_single(
  custom_repo_id: str,
  rules_json: Optional[gr.File],

  ):
  """Returns: repo, revision, predicted_json, metric_cards_md, diag_cards_md, raw_metrics_json"""

+ total_t0 = time.perf_counter()  # TOTAL latency starts here
+
  repo = (custom_repo_id or DEFAULT_REPO).strip()
  revision = "main"
  allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET

+ # Preprocess + window + cap
+ effective_len_before = len(transcript)
  if preprocess:
  transcript = preprocess_text(transcript, add_header, strip_smalltalk)
+ windowed, cap_info = window_then_cap(transcript, soft_token_cap)
+ effective_len_after = len(windowed)

  # Build prompt
  system = system_instructions or SYSTEM_INSTRUCTIONS_BASE
+ prompt = build_prompt(system, context_text or CONTEXT_GUIDE, windowed, allowed, use_fewshot)

  model = get_model(repo, revision, load_in_4bit)

+ # Deterministic pass only
+ raw_json, tok_stats, model_latency = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
  pred_labels = safe_json_labels(raw_json, allowed)

  # Optional fallback sampling (OFF by default)
  fallback_used = False
  if enable_fallback_sampling and not pred_labels:
+ raw_json2, tok_stats2, model_latency2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
  pred_labels2 = safe_json_labels(raw_json2, allowed)
  if pred_labels2:
  pred_labels = pred_labels2
  tok_stats = tok_stats2
+ model_latency = model_latency2
  fallback_used = True

+ total_latency = time.perf_counter() - total_t0
  est_cost = (total_latency / 3600.0) * max(0.0, float(hourly_rate or 0.0))

  # Ground truth

  metric_cards += card("Missing labels", json.dumps(missing, ensure_ascii=False) if gt_labels is not None else "—", "Expected but not predicted")
  metric_cards += card("Extra labels", json.dumps(extra, ensure_ascii=False) if gt_labels is not None else "—", "Predicted but not expected")

+ # Diagnostics cards — now with TWO latency measures
  diag_cards = ""
  diag_cards += card("Model / Rev", f"{repo} / {revision}")
  diag_cards += card("Device", f"{DEVICE} ({GPU_NAME})")
  diag_cards += card("Precision dtype", f"{DTYPE_FALLBACK}")
  diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
  diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))
+ diag_cards += card("Effective text length", f"before={effective_len_before} chars → after={effective_len_after} ({cap_info})")
  diag_cards += card("Tokens", f"prompt={tok_stats['prompt_tokens']}, output={tok_stats['output_tokens']}, total={tok_stats['total_tokens']}", "Token counts influence latency & cost")
+ diag_cards += card("Model latency", f"{model_latency:.2f} s", "Time spent in model.generate(...) only")
+ diag_cards += card("Total latency", f"{total_latency:.2f} s", "End-to-end time (preprocess → model → postprocess)")
  diag_cards += card("Cost (est.)", f"${(est_cost):.6f} @ {hourly_rate:.4f}/hr")
  diag_cards += card("Fallback sampling used", "Yes" if fallback_used else "No", "Sampling can be slower/unstable on T4; off by default")

  "extra": extra if gt_labels is not None else None,
  "per_label": per_label if gt_labels is not None else None,
  "token_stats": tok_stats,
+ "model_latency_seconds": round(model_latency, 3),
+ "total_latency_seconds": round(total_latency, 3),
  "estimated_cost_usd": round(est_cost, 6),
  "fallback_used": fallback_used,
+ "cap_info": cap_info,
  }

  return (

  model = get_model(repo, revision, load_in_4bit)

  rows = [["filename","labels"]]
+ per_sample_rows = [["filename","pred_labels","gold_labels","precision","recall","f1","exact_match","hamming_loss","missing","extra","model_latency_s","total_latency_s","prompt_tokens","output_tokens"]]
  totals = {"tp":0,"fp":0,"fn":0,"pred_total":0,"gold_total":0}
  label_global = {lab: {"tp":0,"fp":0,"fn":0} for lab in allowed}
+ total_prompt_tokens = 0; total_output_tokens = 0; sum_model_s = 0.0; sum_total_s = 0.0
+ n=0; with_gt=0

  system = system_instructions or SYSTEM_INSTRUCTIONS_BASE

  except Exception:
  rows.append([name, "[] # unreadable"]); continue

+ total_t0 = time.perf_counter()  # TOTAL latency per file
+
  if preprocess:
  txt = preprocess_text(txt, add_header, strip_smalltalk)
+ txt_windowed, cap_info = window_then_cap(txt, soft_token_cap)

+ prompt = build_prompt(system, context_text or CONTEXT_GUIDE, txt_windowed, allowed, use_fewshot)

+ raw_json, tok_stats, model_latency = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
  pred = safe_json_labels(raw_json, allowed)
  if enable_fallback_sampling and not pred:
+ raw_json2, tok_stats2, model_latency2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
  pred2 = safe_json_labels(raw_json2, allowed)
  if pred2:
+ pred = pred2; tok_stats = tok_stats2; model_latency = model_latency2
+
+ total_latency = time.perf_counter() - total_t0

  total_prompt_tokens += tok_stats["prompt_tokens"]
  total_output_tokens += tok_stats["output_tokens"]
+ sum_model_s += model_latency
+ sum_total_s += total_latency
  n += 1

  rows.append([name, json.dumps(pred, ensure_ascii=False)])

  round(ham,4),
  json.dumps(missing, ensure_ascii=False),
  json.dumps(extra, ensure_ascii=False),
+ round(model_latency,3), round(total_latency,3),
+ tok_stats["prompt_tokens"], tok_stats["output_tokens"],
+ ])
+ else:
+ per_sample_rows.append([
+ name, json.dumps(pred, ensure_ascii=False), None, None, None, None, None, None, None, None,
+ round(model_latency,3), round(total_latency,3),
+ tok_stats["prompt_tokens"], tok_stats["output_tokens"],
  ])

  tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
  prec = tp / (tp + fp) if (tp + fp) else 0.0
  rec = tp / (tp + fn) if (tp + fn) else 0.0
  f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
+ est_cost = (sum_total_s / 3600.0) * max(0.0, float(hourly_rate or 0.0))

  coverage = {lab: 0 for lab in allowed}
  for r in rows[1:]:
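For reference, a tiny worked example of the micro-averaged metrics computed above (counts invented):

```python
tp, fp, fn = 6, 2, 3
prec = tp / (tp + fp)                 # 6/8 = 0.75
rec = tp / (tp + fn)                  # 6/9 ≈ 0.667
f1 = 2 * prec * rec / (prec + rec)    # ≈ 0.706
```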
 
  "avg_prompt_tokens": round(total_prompt_tokens / n, 2) if n else 0.0,
  "avg_output_tokens": round(total_output_tokens / n, 2) if n else 0.0,
  },
+ "latency_seconds_model_total": round(sum_model_s, 3),
+ "latency_seconds_total": round(sum_total_s, 3),
+ "avg_model_latency_seconds": round(sum_model_s / n, 3) if n else 0.0,
+ "avg_total_latency_seconds": round(sum_total_s / n, 3) if n else 0.0,
  "estimated_cost_usd": round(est_cost, 6),
  }

  diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
  diag_cards += card("Files processed", f"{n} (with GT: {with_gt})")
  diag_cards += card("Tokens (totals)", f"prompt={total_prompt_tokens}, output={total_output_tokens}")
+ diag_cards += card("Latency (model)", f"total={summary['latency_seconds_model_total']} s, avg={summary['avg_model_latency_seconds']} s")
+ diag_cards += card("Latency (total)", f"total={summary['latency_seconds_total']} s, avg={summary['avg_total_latency_seconds']} s")
  diag_cards += card("Cost (est.)", f"${summary['estimated_cost_usd']} @ {hourly_rate:.4f}/hr")
  diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))


  # ----------------------------- UI -----------------------------------

+ with gr.Blocks(title="From Talk to Task — Windowed + Two Latencies") as demo:
  gr.Markdown(
  f"""
  # From Talk to Task — Accuracy & Diagnostics (EN/FR/DE/IT)

  **Default model:** `{DEFAULT_REPO}` (GPU + 4-bit recommended).
+ Now includes **keyword windowing** (keeps early cues) and **two latency measures**:
+ - **Model latency:** time spent inside the model generate call
+ - **Total latency:** end-to-end time (preprocess → model → postprocess)
+
+ Upload ground truth to compute **Precision / Recall / F1 / Exact match / Hamming loss**.
  Upload a **Rules JSON** (`{{"labels":[...]}}`) to override allowed labels.

  **Model output schema:** `{{"labels": [...]}}`
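As an illustration of the Rules JSON mentioned above, a file with this structure overrides the allowed label set (the label names here are the ones used in FEW_SHOTS; the app's full DEFAULT_LABEL_SET is defined elsewhere in app.py and may be larger):

```python
# Contents of an example rules.json, shown as a Python dict for brevity:
rules = {"labels": [
    "schedule_meeting",
    "update_contact_info_non_postal",
    "update_contact_info_postal_address",
    "update_kyc_total_assets",
]}
```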
 
  context = gr.Textbox(label="Context (User prefix)", value=CONTEXT_GUIDE, lines=6)

  with gr.Row():
+ soft_cap = gr.Slider(512, 32768, value=1024, step=1, label="Soft token cap (approx; applied after keyword windows)")
  preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
  add_header = gr.Checkbox(value=True, label="Add cues header")
  strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")