RishiRP committed on
Commit b80450d · verified · 1 Parent(s): 5a60238

Update app.py

Files changed (1)
  1. app.py +372 -741
app.py CHANGED
@@ -1,788 +1,419 @@
1
- # app.py
2
- # From Talk to Task — Windowed extraction + two latency measures
3
- # Model: swiss-ai/Apertus-8B-Instruct-2509
4
- # Few-shot: 1 each EN/FR/DE/IT; deterministic by default; optional sampling fallback toggle.
5
- # Soft token cap: 1024 by default. CUDA fp16 + optional 4-bit. GT scoring + downloads.
6
-
7
- import os
8
- import re
9
- import time
10
- import json
11
- import csv
12
- import zipfile
13
- from pathlib import Path
14
- from typing import Dict, Tuple, Optional, List
15
-
16
- import gradio as gr
17
-
18
- # --------------------------- MODEL / LABELS ---------------------------------
19
-
20
- DEFAULT_REPO = "swiss-ai/Apertus-8B-Instruct-2509"
21
-
22
- DEFAULT_LABEL_SET = [
23
- "plan_contact",
24
- "schedule_meeting",
25
- "update_contact_info_non_postal",
26
- "update_contact_info_postal_address",
27
- "update_kyc_activity",
28
- "update_kyc_origin_of_assets",
29
- "update_kyc_purpose_of_businessrelation",
30
- "update_kyc_total_assets",
31
- ]
32
 
33
- SYSTEM_INSTRUCTIONS_BASE = (
34
- "You are a task extraction assistant. Input transcript language may be English, French, "
35
- "German, or Italian. Return ONLY valid JSON with a single field:\n"
36
- '"labels": a list of strings chosen ONLY from the allowed label set.\n'
37
- "Do NOT add other fields or prose. Do NOT translate labels. If multiple labels apply, return all.\n"
38
- "If none apply, return an empty list."
39
- )
40
-
41
- CONTEXT_GUIDE = (
42
- "- plan_contact: conversation without a firm date/time\n"
43
- "- schedule_meeting: explicit date/time/modality is agreed\n"
44
- "- update_contact_info_non_postal: email/phone updates\n"
45
- "- update_contact_info_postal_address: mailing address updates\n"
46
- "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)\n"
47
- )
48
-
49
- # Few-shot: exactly one per language (compact)
50
- FEW_SHOTS = [
51
- # EN
52
- {"transcript": "Agent: Can we meet Friday 3pm on Teams?\nClient: Yes, Friday 3pm works.\nAgent: I’ll send the invite.",
53
- "labels": ["schedule_meeting"]},
54
- # FR
55
- {"transcript": "Client: Mon numéro a changé: +41 44 000 00 00.\nConseiller: Merci, je mets à jour vos coordonnées.",
56
- "labels": ["update_contact_info_non_postal"]},
57
- # DE
58
- {"transcript": "Kunde: Neue Postadresse: Musterstrasse 1, 8000 Zürich.\nBerater: Danke, ich aktualisiere die Postadresse.",
59
- "labels": ["update_contact_info_postal_address"]},
60
- # IT
61
- {"transcript": "Cliente: Totale patrimonio confermato a 8 milioni CHF.\nConsulente: Aggiorno i dati KYC sul totale degli asset.",
62
- "labels": ["update_kyc_total_assets"]},
63
- ]
64
 
65
- # --------------------- WRITABLE HF CACHE -----------------------------
66
-
67
- HOME = Path(os.environ.get("HOME", "/home/user"))
68
- CACHE_DIR = HOME / ".cache" / "huggingface"
69
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
70
- os.environ.setdefault("HF_HOME", str(CACHE_DIR))
71
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
72
-
73
- HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
74
-
75
- # -------------------- TRANSFORMERS / TORCH ---------------------------
76
-
77
- try:
78
- import torch
79
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
80
- except Exception as e:
81
- raise RuntimeError(
82
- "Missing deps. requirements.txt must include: transformers>=4.56.0, torch, accelerate, huggingface_hub, bitsandbytes, gradio"
83
- ) from e
84
-
85
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
86
- GPU_NAME = torch.cuda.get_device_name(0) if DEVICE == "cuda" else "cpu"
87
- # T4 doesn't support bf16, so use fp16; CPU uses fp32
88
- DTYPE_FALLBACK = torch.float16 if DEVICE == "cuda" else torch.float32
89
-
90
- # -------------------------- HELPERS ---------------------------------
91
-
92
- RE_DISCLAIMER = re.compile(r"^\s*disclaimer\s*:", re.IGNORECASE)
93
- RE_DROP = re.compile(r"(readme|terms|synthetic transcript)", re.IGNORECASE)
94
- SMALLTALK_RE = re.compile(r"\b(thanks?|merci|grazie|danke|tsch(ü|u)ss|ciao|bye|ok(ay)?)\b", re.IGNORECASE)
95
-
96
- # keyword windows (EN/FR/DE/IT) — expand as needed
97
- WINDOW_KEYWORDS = [
98
- # meeting / schedule
99
- r"\b(meet|meeting|schedule|appointment|teams|zoom|google meet|calendar)\b",
100
- r"\b(rendez[- ]?vous|réunion|planifier|calendrier|teams|zoom)\b",
101
- r"\b(termin|treffen|besprechung|kalender|teams|zoom)\b",
102
- r"\b(appuntamento|riunione|calendario|teams|zoom)\b",
103
- # address / phone / email
104
- r"\b(address|street|avenue|road|postcode|phone|email)\b",
105
- r"\b(adresse|rue|avenue|code postal|téléphone|courriel|email)\b",
106
- r"\b(adresse|straße|strasse|plz|telefon|e-?mail)\b",
107
- r"\b(indirizzo|via|cap|telefono|e-?mail)\b",
108
- # KYC assets / totals / origin / purpose
109
- r"\b(total assets|net worth|portfolio|real estate|origin of assets|source of wealth|purpose of relationship)\b",
110
- r"\b(actifs totaux|patrimoine|immobilier|origine des fonds|source de richesse|but de la relation)\b",
111
- r"\b(gesamtverm(ö|o)gen|verm(ö|o)gen|immobilien|herkunft der verm(ö|o)genswerte|zweck der gesch(ä|a)ftsbeziehung)\b",
112
- r"\b(patrimonio totale|immobiliare|origine dei fondi|scopo della relazione)\b",
113
- r"\b(chf|eur|usd|cur[13]|francs?)\b",
114
- r"\b(\d{1,3}([.'’ ]\d{3})*(,\d+)?)(\s?(chf|eur|usd))\b",
115
- ]
116
 
117
- def _json_from_text(text: str) -> str:
118
- s = text.strip()
119
- if s.startswith("{") and s.endswith("}"):
120
- return s
121
- m = re.search(r"\{.*\}", s, re.DOTALL)
122
- return m.group(0) if m else '{"labels": []}'
123
 
124
- def safe_json_labels(s: str, allowed: List[str]) -> List[str]:
125
- try:
126
- data = json.loads(s)
127
- except Exception:
128
- return []
129
- labels = data.get("labels", [])
130
- clean, seen = [], set()
131
- for lab in labels:
132
- if lab in allowed and lab not in seen:
133
- clean.append(lab); seen.add(lab)
134
- return clean
135
-
136
- def read_rules_labels(file_obj: Optional[gr.File]) -> Optional[List[str]]:
137
- if not file_obj:
138
- return None
139
- try:
140
- data = json.loads(Path(file_obj.name).read_text(encoding="utf-8"))
141
- labs = data.get("labels", [])
142
- return [x for x in labs if isinstance(x, str)]
143
- except Exception:
144
- return None
145
 
146
- def read_single_ground_truth(file_obj: Optional[gr.File]) -> Optional[List[str]]:
147
- if not file_obj:
148
- return None
149
  try:
150
- data = json.loads(Path(file_obj.name).read_text(encoding="utf-8"))
151
- labels = data.get("labels", [])
152
- return [lab for lab in labels if isinstance(lab, str)]
153
  except Exception:
154
- return None
155
 
156
- def read_batch_ground_truth_zip(zip_file: Optional[gr.File]) -> Dict[str, List[str]]:
157
- out: Dict[str, List[str]] = {}
158
- if not zip_file:
159
  return out
160
- try:
161
- with zipfile.ZipFile(zip_file.name) as z:
162
- for name in z.namelist():
163
- if not name.lower().endswith(".json"):
164
- continue
165
- try:
166
- data = json.loads(z.read(name).decode("utf-8", errors="replace"))
167
- labs = [lab for lab in data.get("labels", []) if isinstance(lab, str)]
168
- out[Path(name).with_suffix("").name] = labs
169
- except Exception:
170
- pass
171
- except Exception:
172
- pass
173
  return out
174
 
175
- def build_fewshot_block(allowed: List[str]) -> str:
176
- shots = []
177
- for ex in FEW_SHOTS:
178
- shots.append(
179
- f"- Transcript:\n{ex['transcript']}\n- Correct labels (choose subset from {allowed}): {ex['labels']}\n"
180
- )
181
- return "\n".join(shots)
182
-
183
- def build_prompt(system: str, context: str, transcript: str, allowed: List[str], use_fewshot: bool) -> str:
184
- fewshot_section = f"\n### Examples\n{build_fewshot_block(allowed)}\n" if use_fewshot else ""
185
- return (
186
- f"### System\n{system}\n\n"
187
- f"### Allowed label set\n{allowed}\n\n"
188
- f"### Context\n{context}\n"
189
- f"{fewshot_section}\n"
190
- f"### Transcript\n{transcript}\n\n"
191
- "### Output\nReturn JSON only: {\"labels\": [...]}"
192
- )
193
-
194
- def prf1_accuracy(pred: List[str], gold: List[str]) -> Tuple[float, float, float, float, Dict[str, int]]:
195
- pset, gset = set(pred), set(gold)
196
- tp = len(pset & gset); fp = len(pset - gset); fn = len(gset - pset)
197
- prec = tp / (tp + fp) if (tp + fp) else 0.0
198
- rec = tp / (tp + fn) if (tp + fn) else 0.0
199
- f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
200
- denom = len(pset | gset)
201
- acc = (tp / denom) if denom else 1.0
202
- return prec, rec, f1, acc, {"tp": tp, "fp": fp, "fn": fn, "pred_total": len(pset), "gold_total": len(gset)}
203
-
204
- def per_label_counts(pred: List[str], gold: List[str], all_labels: List[str]) -> Dict[str, Dict[str, int]]:
205
- pset, gset = set(pred), set(gold)
206
- out = {}
207
- for lab in all_labels:
208
- tp = 1 if (lab in pset and lab in gset) else 0
209
- fp = 1 if (lab in pset and lab not in gset) else 0
210
- fn = 1 if (lab in gset and lab not in pset) else 0
211
- out[lab] = {"tp": tp, "fp": fp, "fn": fn}
212
- return out
213
-
214
- def hamming_loss(pred: List[str], gold: List[str], all_labels: List[str]) -> float:
215
- pset, gset = set(pred), set(gold)
216
- wrong = 0
217
- for lab in all_labels:
218
- in_p, in_g = (lab in pset), (lab in gset)
219
- wrong += int(in_p != in_g)
220
- return wrong / max(1, len(all_labels))
221
-
222
- def write_csv(path: Path, rows: List[List[str]]):
223
- with path.open("w", newline="", encoding="utf-8") as f:
224
- w = csv.writer(f); w.writerows(rows)
225
-
226
- def card_markdown(title: str, value: str, hint: str = "") -> str:
227
- hint_md = f"<div style='font-size:12px;opacity:0.8'>{hint}</div>" if hint else ""
228
- return f"""
229
- <div style="border:1px solid #3a3a3a;border-radius:10px;padding:10px;margin:6px">
230
- <div style="font-weight:600">{title}</div>
231
- <div style="font-size:20px;margin-top:4px">{value}</div>
232
- {hint_md}
233
- </div>
234
- """
235
-
236
- # ------------------- WINDOWED EXTRACTION (fix for empty labels) -------------------
237
-
238
- def extract_windows(text: str, max_windows: int = 6, half_span_lines: int = 3) -> str:
239
- """
240
- Find up to `max_windows` windows around keyword hits; each window is ±`half_span_lines` lines.
241
- If no hits, return the opening chunk (first ~2000 lines) rather than the tail (a common cause of misses).
242
- """
243
- lines = text.splitlines()
244
- n = len(lines)
245
- # collect hit line indices
246
- hits: List[int] = []
247
- pattern = re.compile("|".join(WINDOW_KEYWORDS), re.IGNORECASE)
248
- for i, ln in enumerate(lines):
249
- if pattern.search(ln):
250
- hits.append(i)
251
- # de-duplicate and cap
252
- unique_hits = []
253
- seen = set()
254
- for idx in hits:
255
- # bucket nearby hits to avoid redundant windows
256
- bucket = idx // 2 # coarse bucketing
257
- if bucket not in seen:
258
- seen.add(bucket)
259
- unique_hits.append(idx)
260
- unique_hits = unique_hits[:max_windows]
261
-
262
- if not unique_hits:
263
- # return the opening chunk; most KYC/context often appears early
264
- return "\n".join(lines[: min(2000, n)])
265
-
266
- # Build windows and merge
267
- windows = []
268
- for idx in unique_hits:
269
- a = max(0, idx - half_span_lines)
270
- b = min(n, idx + half_span_lines + 1)
271
- windows.append("\n".join(lines[a:b]))
272
- return "\n...\n".join(windows)
273
-
274
- # -------------------------- MODEL -----------------------------------
275
-
276
- class HFModel:
277
- def __init__(
278
- self,
279
- repo_id: str,
280
- revision: Optional[str],
281
- token: Optional[str],
282
- load_in_4bit: bool,
283
- dtype
284
- ):
285
  self.repo_id = repo_id
286
- self.revision = revision or "main"
287
- self.token = token
288
- self.load_in_4bit = load_in_4bit and (DEVICE == "cuda")
289
- self.dtype = dtype
290
  self.tokenizer = None
291
  self.model = None
292
 
293
  def load(self):
294
  qcfg = None
295
- if self.load_in_4bit:
296
  qcfg = BitsAndBytesConfig(
297
  load_in_4bit=True,
298
  bnb_4bit_quant_type="nf4",
299
  bnb_4bit_compute_dtype=torch.float16,
300
  bnb_4bit_use_double_quant=True,
301
  )
302
- self.tokenizer = AutoTokenizer.from_pretrained(
303
- self.repo_id, revision=self.revision, token=self.token,
304
- cache_dir=str(CACHE_DIR), use_fast=True, trust_remote_code=True
305
  )
306
- self.model = AutoModelForCausalLM.from_pretrained(
307
- self.repo_id, revision=self.revision, token=self.token,
308
- cache_dir=str(CACHE_DIR), trust_remote_code=True,
309
- torch_dtype=self.dtype,
310
  device_map="auto" if DEVICE == "cuda" else None,
311
- quantization_config=qcfg, low_cpu_mem_usage=True
 
 
312
  )
313
- if DEVICE == "cpu":
314
- self.model = self.model.to(DEVICE)
315
 
316
  @torch.inference_mode()
317
- def generate_json(self, prompt: str, max_new_tokens=48, allow_sampling=False) -> Tuple[str, Dict[str, int], float]:
318
- """
319
- Deterministic by default. If allow_sampling=True (toggle), we use mild temperature.
320
- Returns (json_text, token_stats, model_latency_seconds)
321
- """
322
- tok = self.tokenizer
323
- mdl = self.model
324
-
325
- messages = [{"role": "user", "content": prompt}]
326
- templated = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
327
- inputs = tok([templated], return_tensors="pt", add_special_tokens=False).to(mdl.device)
328
-
329
- kwargs = dict(
330
- max_new_tokens=max_new_tokens,
331
- pad_token_id=tok.eos_token_id,
332
- eos_token_id=tok.eos_token_id,
333
- )
334
- if allow_sampling:
335
- kwargs.update(dict(do_sample=True, temperature=0.25, top_p=0.9))
336
  else:
337
- kwargs.update(dict(do_sample=False, temperature=0.0, top_p=1.0))
338
-
339
- t0 = time.perf_counter()
340
- out = mdl.generate(**inputs, **kwargs)
341
- model_latency = time.perf_counter() - t0
342
-
343
- prompt_tokens = int(inputs.input_ids.shape[-1])
344
- output_tokens = int(out.shape[-1] - inputs.input_ids.shape[-1])
345
- total_tokens = prompt_tokens + output_tokens
346
-
347
- decoded = tok.decode(out[0], skip_special_tokens=True)
348
- gen = decoded[len(templated):].strip() if decoded.startswith(templated) else decoded
349
- return _json_from_text(gen), {
350
- "prompt_tokens": prompt_tokens,
351
- "output_tokens": output_tokens,
352
- "total_tokens": total_tokens,
353
- }, model_latency
354
-
355
- _MODEL_CACHE: Dict[Tuple[str, Optional[str], bool], HFModel] = {}
356
-
357
- def get_model(repo_id: str, revision: Optional[str], load_in_4bit: bool) -> HFModel:
358
- key = (repo_id, revision, load_in_4bit)
359
- if key in _MODEL_CACHE:
360
- return _MODEL_CACHE[key]
361
- mdl = HFModel(repo_id, revision, HF_TOKEN, load_in_4bit, DTYPE_FALLBACK)
362
- mdl.load()
363
- _MODEL_CACHE[key] = mdl
364
- return mdl
365
-
366
- # ---------------------- INFERENCE ROUTES ----------------------------
367
-
368
- def preprocess_text(txt: str, add_header: bool, strip_smalltalk: bool) -> str:
369
- lines = [ln.rstrip() for ln in txt.splitlines()]
370
- lines = [ln for ln in lines if not RE_DISCLAIMER.match(ln)]
371
- lines = [ln for ln in lines if not RE_DROP.search(ln)]
372
- if strip_smalltalk:
373
- lines = [ln for ln in lines if not SMALLTALK_RE.search(ln)]
374
- cleaned = "\n".join(lines[-32768:])
375
- return f"[EMAIL/MESSAGE SIGNAL]\n{cleaned}" if add_header else cleaned
376
-
377
- def window_then_cap(text: str, soft_token_cap: int) -> Tuple[str, str]:
378
- """
379
- Apply keyword windowing; then hard cap by approximate chars (~4 chars/token).
380
- Returns (final_text, info_string).
381
- """
382
- windowed = extract_windows(text)
383
- approx_chars = int(max(soft_token_cap, 0) * 4) if soft_token_cap else 0
384
- info = "windowed"
385
- if approx_chars and len(windowed) > approx_chars:
386
- windowed = windowed[:approx_chars]
387
- info = f"windowed + soft cap ~{soft_token_cap}t"
388
- return windowed, info
389
-
390
- def run_single(
391
- custom_repo_id: str,
392
- rules_json: Optional[gr.File],
393
- system_instructions: str,
394
- context_text: str,
395
- transcript: str,
396
- soft_token_cap: int,
397
- preprocess: bool,
398
- add_header: bool,
399
- strip_smalltalk: bool,
400
- load_in_4bit: bool,
401
- hourly_rate: float,
402
- gt_json_file: Optional[gr.File],
403
- use_fewshot: bool,
404
- enable_fallback_sampling: bool,
405
- ):
406
- """Returns: repo, revision, predicted_json, metric_cards_md, diag_cards_md, raw_metrics_json"""
407
-
408
- total_t0 = time.perf_counter() # TOTAL latency starts here
409
-
410
- repo = (custom_repo_id or DEFAULT_REPO).strip()
411
- revision = "main"
412
- allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET
413
-
414
- # Preprocess + window + cap
415
- effective_len_before = len(transcript)
416
- if preprocess:
417
- transcript = preprocess_text(transcript, add_header, strip_smalltalk)
418
- windowed, cap_info = window_then_cap(transcript, soft_token_cap)
419
- effective_len_after = len(windowed)
420
-
421
- # Build prompt
422
- system = system_instructions or SYSTEM_INSTRUCTIONS_BASE
423
- prompt = build_prompt(system, context_text or CONTEXT_GUIDE, windowed, allowed, use_fewshot)
424
-
425
- model = get_model(repo, revision, load_in_4bit)
426
-
427
- # Deterministic pass only
428
- raw_json, tok_stats, model_latency = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
429
- pred_labels = safe_json_labels(raw_json, allowed)
430
-
431
- # Optional fallback sampling (OFF by default)
432
- fallback_used = False
433
- if enable_fallback_sampling and not pred_labels:
434
- raw_json2, tok_stats2, model_latency2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
435
- pred_labels2 = safe_json_labels(raw_json2, allowed)
436
- if pred_labels2:
437
- pred_labels = pred_labels2
438
- tok_stats = tok_stats2
439
- model_latency = model_latency2
440
- fallback_used = True
441
-
442
- total_latency = time.perf_counter() - total_t0
443
- est_cost = (total_latency / 3600.0) * max(0.0, float(hourly_rate or 0.0))
444
-
445
- # Ground truth
446
- gt_labels = read_single_ground_truth(gt_json_file)
447
- pr = rc = f1 = acc = 0.0
448
- ham = None
449
- missing = []; extra = []; per_label = {}
450
- if gt_labels is not None:
451
- pr, rc, f1, acc, counts = prf1_accuracy(pred_labels, gt_labels)
452
- ham = hamming_loss(pred_labels, gt_labels, allowed)
453
- per_label = per_label_counts(pred_labels, gt_labels, allowed)
454
- missing = sorted(list(set(gt_labels) - set(pred_labels)))
455
- extra = sorted(list(set(pred_labels) - set(gt_labels)))
456
-
457
- # Metric cards
458
- def card(title, val, hint=""):
459
- return card_markdown(title, val, hint)
460
- metric_cards = ""
461
- metric_cards += card("Precision", f"{pr:.3f}" if gt_labels is not None else "—", "Correct positives / All predicted positives")
462
- metric_cards += card("Recall", f"{rc:.3f}" if gt_labels is not None else "—", "Correct positives / All actual positives")
463
- metric_cards += card("F1 score", f"{f1:.3f}" if gt_labels is not None else "—", "Harmonic mean of Precision & Recall")
464
- metric_cards += card("Exact match", f"{1.0 if gt_labels and set(pred_labels)==set(gt_labels) else 0.0 if gt_labels is not None else '—'}", "1.0 if predicted set equals truth")
465
- metric_cards += card("Hamming loss", f"{ham:.3f}" if ham is not None else "—", "Fraction of labels where prediction disagrees with truth (lower better)")
466
- metric_cards += card("Missing labels", json.dumps(missing, ensure_ascii=False) if gt_labels is not None else "—", "Expected but not predicted")
467
- metric_cards += card("Extra labels", json.dumps(extra, ensure_ascii=False) if gt_labels is not None else "—", "Predicted but not expected")
468
-
469
- # Diagnostics cards — now with TWO latency measures
470
- diag_cards = ""
471
- diag_cards += card("Model / Rev", f"{repo} / {revision}")
472
- diag_cards += card("Device", f"{DEVICE} ({GPU_NAME})")
473
- diag_cards += card("Precision dtype", f"{DTYPE_FALLBACK}")
474
- diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
475
- diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))
476
- diag_cards += card("Effective text length", f"before={effective_len_before} chars → after={effective_len_after} ({cap_info})")
477
- diag_cards += card("Tokens", f"prompt={tok_stats['prompt_tokens']}, output={tok_stats['output_tokens']}, total={tok_stats['total_tokens']}", "Token counts influence latency & cost")
478
- diag_cards += card("Model latency", f"{model_latency:.2f} s", "Time spent in model.generate(...) only")
479
- diag_cards += card("Total latency", f"{total_latency:.2f} s", "End-to-end time (preprocess → model → postprocess)")
480
- diag_cards += card("Cost (est.)", f"${(est_cost):.6f} @ {hourly_rate:.4f}/hr")
481
- diag_cards += card("Fallback sampling used", "Yes" if fallback_used else "No", "Sampling can be slower/unstable on T4; off by default")
482
-
483
- raw_metrics = {
484
- "labels_pred": pred_labels,
485
- "ground_truth_labels": gt_labels,
486
- "precision": round(pr, 4) if gt_labels is not None else None,
487
- "recall": round(rc, 4) if gt_labels is not None else None,
488
- "f1": round(f1, 4) if gt_labels is not None else None,
489
- "exact_match": 1.0 if gt_labels and set(pred_labels)==set(gt_labels) else (0.0 if gt_labels is not None else None),
490
- "hamming_loss": round(ham, 4) if ham is not None else None,
491
- "missing": missing if gt_labels is not None else None,
492
- "extra": extra if gt_labels is not None else None,
493
- "per_label": per_label if gt_labels is not None else None,
494
- "token_stats": tok_stats,
495
- "model_latency_seconds": round(model_latency, 3),
496
- "total_latency_seconds": round(total_latency, 3),
497
- "estimated_cost_usd": round(est_cost, 6),
498
- "fallback_used": fallback_used,
499
- "cap_info": cap_info,
500
- }
501
-
502
- return (
503
- repo, revision,
504
- json.dumps({"labels": pred_labels}, ensure_ascii=False),
505
- metric_cards, diag_cards,
506
- json.dumps(raw_metrics, indent=2)
507
- )
508
 
509
- def run_batch(
510
- custom_repo_id: str,
511
- rules_json: Optional[gr.File],
512
- system_instructions: str,
513
- context_text: str,
514
- transcripts_zip: Optional[gr.File],
515
- gt_zip: Optional[gr.File],
516
- soft_token_cap: int,
517
- preprocess: bool,
518
- add_header: bool,
519
- strip_smalltalk: bool,
520
- load_in_4bit: bool,
521
- hourly_rate: float,
522
- use_fewshot: bool,
523
- enable_fallback_sampling: bool,
524
- ):
525
- repo = (custom_repo_id or DEFAULT_REPO).strip()
526
- revision = "main"
527
- if not transcripts_zip:
528
- return repo, revision, "filename,labels\n", "<div>No transcript ZIP provided.</div>", "{}", None, None, None
529
-
530
- allowed = read_rules_labels(rules_json) or DEFAULT_LABEL_SET
531
  try:
532
- z = zipfile.ZipFile(transcripts_zip.name)
533
- txt_names = [n for n in z.namelist() if n.lower().endswith(".txt")]
534
  except Exception as e:
535
- return repo, revision, "filename,labels\n", f"<div>Bad transcript ZIP: {e}</div>", "{}", None, None, None
536
-
537
- gt_map = read_batch_ground_truth_zip(gt_zip)
538
- model = get_model(repo, revision, load_in_4bit)
539
-
540
- rows = [["filename","labels"]]
541
- per_sample_rows = [["filename","pred_labels","gold_labels","precision","recall","f1","exact_match","hamming_loss","missing","extra","model_latency_s","total_latency_s","prompt_tokens","output_tokens"]]
542
- totals = {"tp":0,"fp":0,"fn":0,"pred_total":0,"gold_total":0}
543
- label_global = {lab: {"tp":0,"fp":0,"fn":0} for lab in allowed}
544
- total_prompt_tokens = 0; total_output_tokens = 0; sum_model_s = 0.0; sum_total_s = 0.0
545
- n=0; with_gt=0
546
 
547
- system = system_instructions or SYSTEM_INSTRUCTIONS_BASE
 
548
 
549
- for name in txt_names:
550
- try:
551
- txt = z.read(name).decode("utf-8", errors="replace")
552
- except Exception:
553
- rows.append([name, "[] # unreadable"]); continue
554
-
555
- total_t0 = time.perf_counter() # TOTAL latency per file
556
-
557
- if preprocess:
558
- txt = preprocess_text(txt, add_header, strip_smalltalk)
559
- txt_windowed, cap_info = window_then_cap(txt, soft_token_cap)
560
-
561
- prompt = build_prompt(system, context_text or CONTEXT_GUIDE, txt_windowed, allowed, use_fewshot)
562
-
563
- raw_json, tok_stats, model_latency = model.generate_json(prompt, max_new_tokens=48, allow_sampling=False)
564
- pred = safe_json_labels(raw_json, allowed)
565
- if enable_fallback_sampling and not pred:
566
- raw_json2, tok_stats2, model_latency2 = model.generate_json(prompt, max_new_tokens=48, allow_sampling=True)
567
- pred2 = safe_json_labels(raw_json2, allowed)
568
- if pred2:
569
- pred = pred2; tok_stats = tok_stats2; model_latency = model_latency2
570
-
571
- total_latency = time.perf_counter() - total_t0
572
-
573
- total_prompt_tokens += tok_stats["prompt_tokens"]
574
- total_output_tokens += tok_stats["output_tokens"]
575
- sum_model_s += model_latency
576
- sum_total_s += total_latency
577
- n += 1
578
-
579
- rows.append([name, json.dumps(pred, ensure_ascii=False)])
580
-
581
- stem = Path(name).with_suffix("").name
582
- gold = gt_map.get(stem)
583
- if gold is not None:
584
- with_gt += 1
585
- pr, rc, f1, acc, counts = prf1_accuracy(pred, gold)
586
- ham = hamming_loss(pred, gold, allowed)
587
- missing = sorted(list(set(gold) - set(pred)))
588
- extra = sorted(list(set(pred) - set(gold)))
589
- for k in ["tp","fp","fn","pred_total","gold_total"]:
590
- totals[k] += counts[k]
591
- pl = per_label_counts(pred, gold, allowed)
592
- for lab, c in pl.items():
593
- for k in ["tp","fp","fn"]:
594
- label_global[lab][k] += c[k]
595
- per_sample_rows.append([
596
- name,
597
- json.dumps(pred, ensure_ascii=False),
598
- json.dumps(gold, ensure_ascii=False),
599
- round(pr,4), round(rc,4), round(f1,4),
600
- 1.0 if set(pred)==set(gold) else 0.0,
601
- round(ham,4),
602
- json.dumps(missing, ensure_ascii=False),
603
- json.dumps(extra, ensure_ascii=False),
604
- round(model_latency,3), round(total_latency,3),
605
- tok_stats["prompt_tokens"], tok_stats["output_tokens"],
606
- ])
607
- else:
608
- per_sample_rows.append([
609
- name, json.dumps(pred, ensure_ascii=False), None, None, None, None, None, None, None, None,
610
- round(model_latency,3), round(total_latency,3),
611
- tok_stats["prompt_tokens"], tok_stats["output_tokens"],
612
- ])
613
-
614
- tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
615
- prec = tp / (tp + fp) if (tp + fp) else 0.0
616
- rec = tp / (tp + fn) if (tp + fn) else 0.0
617
- f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
618
- est_cost = (sum_total_s / 3600.0) * max(0.0, float(hourly_rate or 0.0))
619
-
620
- coverage = {lab: 0 for lab in allowed}
621
- for r in rows[1:]:
622
- try:
623
- labs = set(json.loads(r[1]))
624
- for lab in labs:
625
- if lab in coverage:
626
- coverage[lab] += 1
627
- except Exception:
628
- pass
629
-
630
- summary = {
631
- "files_processed": n,
632
- "files_with_ground_truth": with_gt,
633
- "labels_allowed": allowed,
634
- "precision_micro": round(prec, 4),
635
- "recall_micro": round(rec, 4),
636
- "f1_micro": round(f1, 4),
637
- "per_label_counts": label_global,
638
- "coverage_counts": coverage,
639
- "token_stats": {
640
- "prompt_tokens_total": total_prompt_tokens,
641
- "output_tokens_total": total_output_tokens,
642
- "total_tokens": total_prompt_tokens + total_output_tokens,
643
- "avg_prompt_tokens": round(total_prompt_tokens / n, 2) if n else 0.0,
644
- "avg_output_tokens": round(total_output_tokens / n, 2) if n else 0.0,
645
- },
646
- "latency_seconds_model_total": round(sum_model_s, 3),
647
- "latency_seconds_total": round(sum_total_s, 3),
648
- "avg_model_latency_seconds": round(sum_model_s / n, 3) if n else 0.0,
649
- "avg_total_latency_seconds": round(sum_total_s / n, 3) if n else 0.0,
650
- "estimated_cost_usd": round(est_cost, 6),
651
- }
652
-
653
- # Diagnostic cards (HTML)
654
- diag_cards = ""
655
- def card(title, val, hint=""):
656
- return card_markdown(title, val, hint)
657
- diag_cards += card("Model / Rev", f"{repo} / {revision}")
658
- diag_cards += card("Device", f"{DEVICE} ({GPU_NAME})")
659
- diag_cards += card("Precision dtype", f"{DTYPE_FALLBACK}")
660
- diag_cards += card("4-bit", f"{bool(load_in_4bit)}")
661
- diag_cards += card("Files processed", f"{n} (with GT: {with_gt})")
662
- diag_cards += card("Tokens (totals)", f"prompt={total_prompt_tokens}, output={total_output_tokens}")
663
- diag_cards += card("Latency (model)", f"total={summary['latency_seconds_model_total']} s, avg={summary['avg_model_latency_seconds']} s")
664
- diag_cards += card("Latency (total)", f"total={summary['latency_seconds_total']} s, avg={summary['avg_total_latency_seconds']} s")
665
- diag_cards += card("Cost (est.)", f"${summary['estimated_cost_usd']} @ {hourly_rate:.4f}/hr")
666
- diag_cards += card("Allowed labels", json.dumps(allowed, ensure_ascii=False))
667
-
668
- # Artifacts
669
- tmp_dir = Path("/tmp")
670
- pred_csv = tmp_dir / "predictions.csv"
671
- per_sample_csv = tmp_dir / "per_sample_metrics.csv"
672
- summary_json = tmp_dir / "summary_metrics.json"
673
- with pred_csv.open("w", newline="", encoding="utf-8") as f:
674
- w = csv.writer(f); w.writerows(rows)
675
- with per_sample_csv.open("w", newline="", encoding="utf-8") as f:
676
- w = csv.writer(f); w.writerows(per_sample_rows)
677
- summary_json.write_text(json.dumps(summary, indent=2), encoding="utf-8")
678
-
679
- return (
680
- repo, revision,
681
- "\n".join([",".join(r) for r in rows]),
682
- diag_cards,
683
- json.dumps(summary, indent=2),
684
- str(pred_csv), str(per_sample_csv), str(summary_json)
685
  )
686
 
687
- # ----------------------------- UI -----------------------------------
688
 
689
- with gr.Blocks(title="From Talk to Task — Windowed + Two Latencies") as demo:
 
690
  gr.Markdown(
691
- f"""
692
- # From Talk to Task: Accuracy & Diagnostics (EN/FR/DE/IT)
693
-
694
- **Default model:** `{DEFAULT_REPO}` (GPU + 4-bit recommended).
695
- Now includes **keyword windowing** (keeps early cues) and **two latency measures**:
696
- - **Model latency:** time spent inside the model generate call
697
- - **Total latency:** end-to-end time (preprocess → model → postprocess)
698
-
699
- Upload ground truth to compute **Precision / Recall / F1 / Exact match / Hamming loss**.
700
- Upload a **Rules JSON** (`{{"labels":[...]}}`) to override allowed labels.
701
-
702
- **Model output schema:** `{{"labels": [...]}}`
703
- """
704
  )
705
 
706
  with gr.Row():
707
- custom_repo = gr.Textbox(
708
- label="Model repo (empty → default)",
709
- placeholder="e.g. swiss-ai/Apertus-8B-Instruct-2509"
710
- )
711
- load_4bit = gr.Checkbox(value=True, label="Load in 4-bit (GPU only)")
712
- use_fewshot = gr.Checkbox(value=True, label="Use few-shot (1 per EN/FR/DE/IT)")
713
- enable_fallback_sampling = gr.Checkbox(value=False, label="Enable fallback sampling (slower/unstable on T4)")
714
-
715
- rules_file = gr.File(label="Rules JSON (optional) — overrides allowed labels", file_types=[".json"])
716
-
717
- system = gr.Textbox(label="Instructions (System)", value=SYSTEM_INSTRUCTIONS_BASE, lines=6)
718
- context = gr.Textbox(label="Context (User prefix)", value=CONTEXT_GUIDE, lines=6)
719
-
720
- with gr.Row():
721
- soft_cap = gr.Slider(512, 32768, value=1024, step=1, label="Soft token cap (approx; applied after keyword windows)")
722
- preprocess = gr.Checkbox(value=True, label="Enable preprocessing")
723
- add_header = gr.Checkbox(value=True, label="Add cues header")
724
- strip_smalltalk = gr.Checkbox(value=False, label="Strip smalltalk")
725
- hourly_rate = gr.Number(value=0.40, precision=4, label="Hourly hardware price (USD) for cost estimate")
726
-
727
- with gr.Tabs():
728
- with gr.Tab("Single Transcript"):
729
- transcript = gr.Textbox(label="Paste transcript (EN/FR/DE/IT)", lines=14)
730
- gt_single = gr.File(label="Ground truth JSON — {\"labels\": [..]}", file_types=[".json"])
731
- run_btn = gr.Button("Run (Single)", variant="primary")
732
-
733
- repo_used = gr.Textbox(label="Repo used", interactive=False)
734
- rev_used = gr.Textbox(label="Revision", interactive=False)
735
- json_out = gr.Code(label="Predicted JSON", language="json")
736
-
737
- metric_cards_md = gr.HTML(label="Metrics (cards)")
738
- diag_cards_md = gr.HTML(label="Diagnostics (cards)")
739
- raw_metrics = gr.Code(label="Raw metrics JSON", language="json")
740
-
741
- def _single(*args):
742
- return run_single(*args)
743
-
744
- run_btn.click(
745
- _single,
746
- inputs=[
747
- custom_repo, rules_file, system, context, transcript,
748
- soft_cap, preprocess, add_header, strip_smalltalk,
749
- load_4bit, hourly_rate, gt_single, use_fewshot, enable_fallback_sampling
750
- ],
751
- outputs=[repo_used, rev_used, json_out, metric_cards_md, diag_cards_md, raw_metrics],
752
  )
753
-
754
- with gr.Tab("Batch (ZIP)"):
755
- zip_in = gr.File(label="Upload ZIP of .txt transcripts", file_types=[".zip"])
756
- gt_zip = gr.File(label="Upload ZIP of ground truth .json (match basenames)", file_types=[".zip"])
757
- run_batch_btn = gr.Button("Run (Batch)", variant="primary")
758
-
759
- repo_used_b = gr.Textbox(label="Repo used", interactive=False)
760
- rev_used_b = gr.Textbox(label="Revision", interactive=False)
761
- csv_out = gr.Textbox(label="Predictions CSV (filename,labels)", lines=12)
762
-
763
- diag_cards_b = gr.HTML(label="Diagnostics (cards)")
764
- metrics_out_b = gr.Code(label="Summary metrics JSON", language="json")
765
-
766
- preds_file = gr.File(label="Download predictions.csv")
767
- per_sample_file = gr.File(label="Download per_sample_metrics.csv")
768
- summary_file = gr.File(label="Download summary_metrics.json")
769
-
770
- def _batch(*args):
771
- return run_batch(*args)
772
-
773
- run_batch_btn.click(
774
- _batch,
775
- inputs=[
776
- custom_repo, rules_file, system, context, zip_in, gt_zip,
777
- soft_cap, preprocess, add_header, strip_smalltalk,
778
- load_4bit, hourly_rate, use_fewshot, enable_fallback_sampling
779
- ],
780
- outputs=[repo_used_b, rev_used_b, csv_out, diag_cards_b, metrics_out_b, preds_file, per_sample_file, summary_file],
781
  )
782
 
783
- gr.Markdown(
784
- f"- **HF_TOKEN:** {'✅ set' if HF_TOKEN else '⚠️ not set (only needed for gated/private)'} \n"
785
- f"- **Device:** {DEVICE} ({GPU_NAME}) | **DType:** {DTYPE_FALLBACK} | **Cache dir:** `{CACHE_DIR}`"
786
  )
787
 
788
  if __name__ == "__main__":
1
 
2
+ Allowed Labels (strict, case-insensitive match; output must use canonical label text exactly):
3
+ {allowed_labels_list}
4
+
5
+ Instructions:
6
+ 1) Extract every concrete task the advisor or client must complete.
7
+ 2) For each, choose ONE label from Allowed Labels (or leave empty if none match).
8
+ 3) Output STRICT JSON only, no prose:
9
+ {{
10
+ "labels": ["LabelA","LabelB", ...],
11
+ "tasks": [
12
+ {{"label": "LabelA", "explanation": "…", "evidence": "…"}},
13
+ {{"label": "LabelB", "explanation": "…", "evidence": "…"}}
14
+ ]
15
+ }}
16
+ """
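For reference, a hypothetical response that satisfies this schema might look like the following (the doubled braces above are only format-string escapes; the label name here is illustrative and would come from DEFAULT_ALLOWED_LABELS or the UI override):

{
  "labels": ["schedule_meeting"],
  "tasks": [
    {"label": "schedule_meeting", "explanation": "Client agreed to meet Friday at 3pm on Teams.", "evidence": "Client: Yes, Friday 3pm works."}
  ]
}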
17
 
18
+ # =========================
19
+ # Utilities
20
+ # =========================
21
+ def _now_ms() -> int:
22
+ return int(time.time() * 1000)
23
+
24
+ def read_file_to_text(file) -> str:
+ # gr.File(type="filepath") passes a path string, so open the file rather than calling .read() on it
+ if not file:
+ return ""
+ path = file if isinstance(file, str) else file.name
+ name = path.lower()
+ with open(path, "rb") as fh:
+ data = fh.read()
29
+ # Restrict to light parsers (txt/md/json) for speed/reliability
30
+ if name.endswith(".json"):
31
+ try:
32
+ obj = json.loads(data.decode("utf-8", errors="ignore"))
33
+ # Accept either {"transcript": "..."} or list/str
34
+ if isinstance(obj, dict) and "transcript" in obj:
35
+ return str(obj["transcript"])
36
+ return json.dumps(obj, ensure_ascii=False)
37
+ except Exception:
38
+ return data.decode("utf-8", errors="ignore")
39
+ else:
40
+ # txt / md or anything texty
41
+ try:
42
+ return data.decode("utf-8", errors="ignore")
43
+ except Exception:
44
+ try:
45
+ return data.decode("latin-1", errors="ignore")
46
+ except Exception:
47
+ return ""
48
 
49
+ def normalize_labels(labels: List[str]) -> List[str]:
50
+ return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
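A quick sanity check of the intended strip/dedupe behaviour (hypothetical input):

normalize_labels(["plan_contact ", "plan_contact", "", "schedule_meeting"])
# -> ["plan_contact", "schedule_meeting"]  (whitespace stripped, empties dropped, order-preserving dedupe)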
51
 
52
+ def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
53
+ """
54
+ Build a case-insensitive map: lowercase -> canonical label
55
+ """
56
+ m = {}
57
+ for lab in allowed:
58
+ m[lab.lower()] = lab
59
+ return m
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ def robust_json_extract(text: str) -> Dict[str, Any]:
62
+ """
63
+ Try to parse strict JSON from model output.
64
+ If the model added extra tokens, strip to first {...} block.
65
+ """
66
+ if not text:
67
+ return {"labels": [], "tasks": []}
68
+
69
+ # Find first JSON object
70
+ start = text.find("{")
71
+ end = text.rfind("}")
72
+ if start != -1 and end != -1 and end > start:
73
+ candidate = text[start : end + 1]
74
+ else:
75
+ candidate = text
76
+
77
+ # Remove trailing junk commas and try json.loads
78
  try:
79
+ return json.loads(candidate)
 
 
80
  except Exception:
81
+ # Fallback: try to repair common issues
82
+ candidate = re.sub(r",\s*}", "}", candidate)
83
+ candidate = re.sub(r",\s*]", "]", candidate)
84
+ try:
85
+ return json.loads(candidate)
86
+ except Exception:
87
+ return {"labels": [], "tasks": []}
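A sanity check of the repair branch (hypothetical model output with surrounding prose and trailing commas):

robust_json_extract('Sure! {"labels": ["LabelA",], "tasks": [],} end of answer')
# -> {"labels": ["LabelA"], "tasks": []}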
88
 
89
+ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
90
+ """
91
+ Keep only tasks whose label ∈ allowed; map case-insensitively to canonical.
92
+ """
93
+ out = {"labels": [], "tasks": []}
94
+ if not isinstance(pred, dict):
95
  return out
96
+ raw_labels = pred.get("labels", []) or []
97
+ raw_tasks = pred.get("tasks", []) or []
98
+
99
+ allowed_map = canonicalize_map(allowed)
100
+
101
+ # Filter labels
102
+ filt_labels: List[str] = []
103
+ for l in raw_labels:
104
+ if not isinstance(l, str):
105
+ continue
106
+ k = l.strip().lower()
107
+ if k in allowed_map:
108
+ filt_labels.append(allowed_map[k])
109
+ filt_labels = normalize_labels(filt_labels)
110
+
111
+ # Filter tasks
112
+ filt_tasks = []
113
+ for t in raw_tasks:
114
+ if not isinstance(t, dict):
115
+ continue
116
+ lbl = t.get("label", "")
117
+ k = str(lbl).strip().lower()
118
+ if k in allowed_map:
119
+ new_t = dict(t)
120
+ new_t["label"] = allowed_map[k]
121
+ filt_tasks.append(new_t)
122
+
123
+ # Ensure labels reflect tasks (union)
124
+ from_tasks = [tt["label"] for tt in filt_tasks if isinstance(tt.get("label"), str)]
125
+ merged = normalize_labels(list(set(filt_labels) | set(from_tasks)))
126
+
127
+ out["labels"] = merged
128
+ out["tasks"] = filt_tasks
129
  return out
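A hypothetical round trip through this filter, assuming "schedule_meeting" is the only allowed label:

pred = {"labels": ["SCHEDULE_MEETING", "made_up_label"],
        "tasks": [{"label": "Schedule_Meeting", "explanation": "Meeting agreed for Friday.", "evidence": "Friday 3pm works."}]}
restrict_to_allowed(pred, ["schedule_meeting"])
# -> {"labels": ["schedule_meeting"],
#     "tasks": [{"label": "schedule_meeting", "explanation": "Meeting agreed for Friday.", "evidence": "Friday 3pm works."}]}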
130
 
131
+ def truncate_tokens(tokenizer, text: str, max_input_tokens: int) -> str:
132
+ if max_input_tokens <= 0:
133
+ return text
134
+ toks = tokenizer(text, add_special_tokens=False, return_attention_mask=False, return_tensors=None)["input_ids"]
135
+ if len(toks) <= max_input_tokens:
136
+ return text
137
+ # Keep the tail (most recent part of the convo often carries actionable tasks)
138
+ keep_ids = toks[-max_input_tokens:]
139
+ return tokenizer.decode(keep_ids, skip_special_tokens=True)
140
+
141
+ # =========================
142
+ # Model Loading
143
+ # =========================
144
+ class ModelWrapper:
145
+ def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
146
  self.repo_id = repo_id
147
+ self.hf_token = hf_token
148
+ self.load_in_4bit = load_in_4bit
 
 
149
  self.tokenizer = None
150
  self.model = None
151
 
152
  def load(self):
153
  qcfg = None
154
+ if self.load_in_4bit and DEVICE == "cuda":
155
  qcfg = BitsAndBytesConfig(
156
  load_in_4bit=True,
157
  bnb_4bit_quant_type="nf4",
158
  bnb_4bit_compute_dtype=torch.float16,
159
  bnb_4bit_use_double_quant=True,
160
  )
161
+
162
+ tok = AutoTokenizer.from_pretrained(
163
+ self.repo_id,
164
+ token=self.hf_token,
165
+ cache_dir=str(SPACE_CACHE),
166
+ trust_remote_code=True,
167
+ use_fast=True,
168
  )
169
+ # Some models lack pad token—safe default
170
+ if tok.pad_token is None and tok.eos_token is not None:
171
+ tok.pad_token = tok.eos_token
172
+
173
+ model = AutoModelForCausalLM.from_pretrained(
174
+ self.repo_id,
175
+ token=self.hf_token,
176
+ cache_dir=str(SPACE_CACHE),
177
+ trust_remote_code=True,
178
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
179
  device_map="auto" if DEVICE == "cuda" else None,
180
+ low_cpu_mem_usage=True,
181
+ quantization_config=qcfg,
182
+ attn_implementation="sdpa", # T4-safe and faster than 'eager'
183
  )
184
+ self.tokenizer = tok
185
+ self.model = model
186
 
187
  @torch.inference_mode()
188
+ def generate(self, system_prompt: str, user_prompt: str) -> str:
189
+ # Chat template if available; otherwise a simple format
190
+ if hasattr(self.tokenizer, "apply_chat_template"):
191
+ messages = [
192
+ {"role": "system", "content": system_prompt},
193
+ {"role": "user", "content": user_prompt},
194
+ ]
195
+ # apply_chat_template returns a bare tensor here; wrap it in a dict so **input_ids works in generate()
+ ids = self.tokenizer.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.model.device)
+ input_ids = {"input_ids": ids}
200
  else:
201
+ text = f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n"
202
+ input_ids = self.tokenizer(text, return_tensors="pt").to(self.model.device)
203
+
204
+ with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
205
+ out_ids = self.model.generate(
206
+ **input_ids,
207
+ generation_config=GEN_CONFIG,
208
+ eos_token_id=self.tokenizer.eos_token_id,
209
+ pad_token_id=self.tokenizer.pad_token_id,
210
+ )
211
+ out = self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
212
+ # Heuristic: strip the prompting part if the model echoes input
213
+ if "}" in out:
214
+ tail = out[out.rfind("}") + 1 :]
215
+ body = out[: out.rfind("}") + 1]
216
+ # Prefer the last JSON object if multiple
217
+ if "{" in tail and "}" in tail:
218
+ # do nothing—rare; handled by robust_json_extract
219
+ pass
220
+ return body
221
+ return out
222
 
223
+ # Keep one live model per repo for snappy re-runs
224
+ _MODEL_CACHE: Dict[str, ModelWrapper] = {}
225
+
226
+ def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
227
+ key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
228
+ if key not in _MODEL_CACHE:
229
+ mw = ModelWrapper(repo_id, hf_token, load_in_4bit)
230
+ mw.load()
231
+ _MODEL_CACHE[key] = mw
232
+ return _MODEL_CACHE[key]
233
+
234
+ # =========================
235
+ # Inference Pipeline
236
+ # =========================
237
+ def run_extraction(
238
+ transcript_text: str,
239
+ transcript_file: gr.File,
240
+ allowed_labels_text: str,
241
+ model_repo: str,
242
+ use_4bit: bool,
243
+ max_input_tokens: int,
244
+ hf_token: str,
245
+ ) -> Tuple[str, str, str, str]:
246
+
247
+ t0 = _now_ms()
248
+
249
+ # 1) Get transcript: prefer file (drag-drop), else textarea
250
+ raw_text = ""
251
+ if transcript_file:
252
+ raw_text = read_file_to_text(transcript_file)
253
+ if not raw_text:
254
+ raw_text = transcript_text or ""
255
+ raw_text = raw_text.strip()
256
+
257
+ if not raw_text:
258
+ return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, ensure_ascii=False, indent=2)
259
+
260
+ # 2) Allowed labels: combine UI text with default (so we NEVER end up empty)
261
+ user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
262
+ allowed = normalize_labels(user_allowed or DEFAULT_ALLOWED_LABELS)
263
+
264
+ # 3) Load model
265
+ hf_tok = hf_token.strip() or None
266
  try:
267
+ model = get_model(model_repo, hf_tok, load_in_4bit=use_4bit)
 
268
  except Exception as e:
269
+ msg = (
270
+ f"Model load failed for '{model_repo}'. If gated/private, set HF_TOKEN in Space secrets.\n"
271
+ f"Error: {e}"
272
+ )
273
+ return "", "", msg, json.dumps({"labels": [], "tasks": []}, ensure_ascii=False, indent=2)
274
 
275
+ # 4) Truncate input to speed up
276
+ trunc_text = truncate_tokens(model.tokenizer, raw_text, max_input_tokens=max_input_tokens)
277
 
278
+ # 5) Build prompts
279
+ allowed_list_str = "\n".join(f"- {lab}" for lab in allowed)
280
+ user_prompt = USER_PROMPT_TEMPLATE.format(
281
+ transcript=trunc_text,
282
+ allowed_labels_list=allowed_list_str,
283
  )
284
 
285
+ # 6) Generate
286
+ t1 = _now_ms()
287
+ try:
288
+ model_out = model.generate(SYSTEM_PROMPT, user_prompt)
289
+ except Exception as e:
290
+ return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, ensure_ascii=False, indent=2)
291
+ t2 = _now_ms()
292
+
293
+ # 7) Parse & filter strictly to allowed
294
+ parsed = robust_json_extract(model_out)
295
+ filtered = restrict_to_allowed(parsed, allowed)
296
+
297
+ # 8) Compose UI outputs
298
+ # Diagnostics
299
+ diag = [
300
+ f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
301
+ f"Model: {model_repo}",
302
+ f"Tokens (input, approx): ≤ {max_input_tokens}",
303
+ f"Latency: load+prep {(t1 - t0)} ms, generate {(t2 - t1)} ms, total {(t2 - t0)} ms",
304
+ f"Allowed Labels Used (n={len(allowed)}): {', '.join(allowed)}",
305
+ ]
306
+ diag_str = "\n".join(diag)
307
+
308
+ # Summary plain text
309
+ labs = filtered.get("labels", [])
310
+ tasks = filtered.get("tasks", [])
311
+ summ_lines = []
312
+ if labs:
313
+ summ_lines.append("Detected labels:\n - " + "\n - ".join(labs))
314
+ else:
315
+ summ_lines.append("Detected labels: (none)")
316
+
317
+ if tasks:
318
+ summ_lines.append("\nTasks:")
319
+ for t in tasks:
320
+ lab = t.get("label", "")
321
+ expl = t.get("explanation", "")
322
+ ev = t.get("evidence", "")
323
+ summ_lines.append(f"• [{lab}] {expl} | evidence: {ev[:140]}{'…' if len(ev)>140 else ''}")
324
+ else:
325
+ summ_lines.append("\nTasks: (none)")
326
+
327
+ summary = "\n".join(summ_lines)
328
+
329
+ # JSON pretty
330
+ json_str = json.dumps(filtered, ensure_ascii=False, indent=2)
331
+
332
+ # Raw model text (to help debug label empty issues)
333
+ raw_out = model_out.strip()
334
+
335
+ return summary, json_str, diag_str, raw_out
336
+
337
+ # =========================
338
+ # UI
339
+ # =========================
340
+ MODEL_CHOICES = [
341
+ "swiss-ai/Apertus-8B-Instruct-2509", # default
342
+ "meta-llama/Meta-Llama-3-8B-Instruct", # may be gated; handled in code
343
+ "mistralai/Mistral-7B-Instruct-v0.3", # widely available, strong baseline
344
+ ]
345
 
346
+ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
347
+ gr.Markdown("# Talk2Task — Task Extraction Demo")
348
  gr.Markdown(
349
+ "Drop a transcript file **or** paste text, choose a model, and get strict JSON back. "
350
+ "For best speed, keep inputs concise or lower the input token limit."
351
  )
352
 
353
  with gr.Row():
354
+ with gr.Column(scale=3):
355
+ transcript_file = gr.File(
356
+ label="Drag & drop transcript (.txt / .md / .json)",
357
+ file_types=[".txt", ".md", ".json"],
358
+ type="filepath",
359
  )
360
+ transcript_text = gr.Textbox(
361
+ label="Or paste transcript here",
362
+ lines=14,
363
+ placeholder="Paste conversation transcript…",
364
  )
365
+ allowed_labels_text = gr.Textbox(
366
+ label="Allowed Labels (one per line) — leave empty to use defaults",
367
+ value="",
368
+ lines=8,
369
+ )
370
+ with gr.Column(scale=2):
371
+ model_repo = gr.Dropdown(
372
+ label="Model Repository",
373
+ choices=MODEL_CHOICES,
374
+ value=MODEL_CHOICES[0],
375
+ )
376
+ use_4bit = gr.Checkbox(
377
+ label="Use 4-bit quantization (recommended on GPU/T4)",
378
+ value=True,
379
+ )
380
+ max_input_tokens = gr.Slider(
381
+ label="Max input tokens (truncate from end for speed)",
382
+ minimum=1024,
383
+ maximum=8192,
384
+ step=512,
385
+ value=4096,
386
+ )
387
+ hf_token = gr.Textbox(
388
+ label="HF_TOKEN (only needed for gated/private models)",
389
+ type="password",
390
+ value=os.environ.get("HF_TOKEN", ""),
391
+ )
392
+ run_btn = gr.Button("Run Extraction", variant="primary")
393
 
394
+ with gr.Row():
395
+ with gr.Column():
396
+ summary_out = gr.Textbox(label="Summary", lines=10)
397
+ with gr.Column():
398
+ json_out = gr.Code(label="Strict JSON Output", language="json")
399
+ with gr.Row():
400
+ with gr.Column():
401
+ diag_out = gr.Textbox(label="Diagnostics & Timing", lines=8)
402
+ with gr.Column():
403
+ raw_out = gr.Textbox(label="Raw Model Output (debug)", lines=8)
404
+
405
+ run_btn.click(
406
+ fn=run_extraction,
407
+ inputs=[
408
+ transcript_text,
409
+ transcript_file,
410
+ allowed_labels_text,
411
+ model_repo,
412
+ use_4bit,
413
+ max_input_tokens,
414
+ hf_token,
415
+ ],
416
+ outputs=[summary_out, json_out, diag_out, raw_out],
417
  )
418
 
419
  if __name__ == "__main__":