RishiRP committed on
Commit 6acd2cc · verified · 1 Parent(s): 7cd757f

Update app.py

Files changed (1): app.py (+447 -624)
app.py CHANGED
@@ -1,232 +1,149 @@
1
- """
2
- Gradio application for the Swiss {ai} Weeks "From Talk to Task" challenge.
3
-
4
- This app provides two modes of operation:
5
-
6
- * **Single transcript** – Paste or upload a single conversation transcript
7
- and the model will extract actionable tasks according to a
8
- predefined list of labels. It outputs a human‑readable summary,
9
- strict JSON, and diagnostic information (e.g. device, latency).
10
-
11
- * **Batch evaluation** – Upload a ZIP archive containing one `.txt`
12
- transcript per call and a matching `.json` file with the ground
13
- truth labels. The app runs the model on each transcript, compares
14
- the predictions against the true labels and computes the official
15
- weighted score used by the challenge organisers. It also reports
16
- precision, recall and F1, and provides a per‑sample results table
17
- that can be downloaded as CSV.
18
-
19
- The official allowed labels and evaluation function are taken from
20
- the challenge repository README. False negatives are penalised twice as
21
- heavily as false positives, so recall is especially important.
22
- """
23
-
24
- import os
25
- import io
26
- import re
27
- import json
28
- import time
29
- import zipfile
30
- from pathlib import Path
31
- from typing import List, Dict, Any, Tuple, Optional
32
-
33
- import gradio as gr
34
- import numpy as np
35
- import pandas as pd
36
- import torch
37
- from transformers import (
38
- AutoTokenizer,
39
- AutoModelForCausalLM,
40
- BitsAndBytesConfig,
41
- GenerationConfig,
42
- )
43
-
44
-
45
- # =============================================================================
46
- # Configuration and Constants
47
- # =============================================================================
48
-
49
- # Cache directory for HuggingFace models
50
- SPACE_CACHE = Path.home() / ".cache" / "huggingface"
51
- SPACE_CACHE.mkdir(parents=True, exist_ok=True)
52
-
53
- # Device selection
54
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
55
-
56
- # Generation parameters tuned for speed/quality
57
- GEN_CONFIG = GenerationConfig(
58
- temperature=0.2,
59
- top_p=0.9,
60
- do_sample=False,
61
- max_new_tokens=256,
62
- )
63
-
64
- # Official allowed task labels
65
- DEFAULT_ALLOWED_LABELS = [
66
- "plan_contact",
67
- "schedule_meeting",
68
- "update_contact_info_non_postal",
69
- "update_contact_info_postal_address",
70
- "update_kyc_activity",
71
- "update_kyc_origin_of_assets",
72
- "update_kyc_purpose_of_businessrelation",
73
- "update_kyc_total_assets",
74
- ]
75
-
76
- # System and user prompt templates
77
- SYSTEM_PROMPT = (
78
- "You are a precise banking assistant that extracts ACTIONABLE TASKS "
79
- "from client–advisor transcripts. Return STRICT JSON with fields: "
80
- '{"labels": ["<Label1>", ...], "tasks": [{"label": "<Label1>", "explanation": "<why>", "evidence": "<span>"}]} '"
81
- "Only use labels from the provided Allowed Labels list; if none apply, return an empty list."
82
- )
83
 
84
- USER_PROMPT_TEMPLATE = """Transcript:
85
- ```
86
- {transcript}
87
- ```
88
-
89
- Allowed Labels:
90
  {allowed_labels_list}
91
 
92
- Output STRICT JSON only, no prose:
93
- {{
94
- "labels": ["LabelA", "LabelB", ...],
95
- "tasks": [
96
- {{"label": "LabelA", "explanation": "…", "evidence": "…"}},
97
- {{"label": "LabelB", "explanation": "…", "evidence": "…"}}
98
- ]
99
- }}
100
- """
101
-
102
 
103
- # =============================================================================
104
- # Utility Functions
105
- # =============================================================================
 
 
106
 
 
 
 
107
  def _now_ms() -> int:
108
- """Return the current time in milliseconds."""
109
  return int(time.time() * 1000)
110
 
111
-
112
- def read_file_to_text(file: Optional[gr.File]) -> str:
113
- """
114
- Read an uploaded file (txt/md/json) to a string. For JSON files,
115
- return the value of the "transcript" field if present, otherwise
116
- return the entire JSON as a compact string.
117
- """
118
- if not file or not file.name:
119
- return ""
120
- name = file.name.lower()
121
- data = file.read()
122
- if name.endswith(".json"):
123
- try:
124
- obj = json.loads(data.decode("utf-8", errors="ignore"))
125
- if isinstance(obj, dict) and "transcript" in obj:
126
- return str(obj["transcript"])
127
- return json.dumps(obj, ensure_ascii=False)
128
- except Exception:
129
- return data.decode("utf-8", errors="ignore")
130
- return data.decode("utf-8", errors="ignore")
131
-
132
-
133
  def normalize_labels(labels: List[str]) -> List[str]:
134
- """Deduplicate and strip whitespace from a list of labels."""
135
  return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
136
 
137
-
138
  def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
139
- """Map lowercase labels to their canonical names."""
140
  return {lab.lower(): lab for lab in allowed}
141
 
142
-
143
  def robust_json_extract(text: str) -> Dict[str, Any]:
144
- """
145
- Extract the first JSON object from a string. Removes common
146
- trailing comma mistakes. Returns an empty prediction if no JSON
147
- object is found.
148
- """
149
  if not text:
150
  return {"labels": [], "tasks": []}
151
  start, end = text.find("{"), text.rfind("}")
152
- candidate = text[start:end + 1] if (start != -1 and end != -1) else text
153
- candidate = re.sub(r",\s*}\s*", "}", candidate)
154
- candidate = re.sub(r",\s*]\s*", "]", candidate)
155
  try:
156
  return json.loads(candidate)
157
  except Exception:
158
- return {"labels": [], "tasks": []}
159
-
 
 
 
 
160
 
161
  def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
162
- """
163
- Restrict predicted labels and tasks to those in the allowed list.
164
- Case‑insensitive matching is performed and the canonical form is
165
- returned. Duplicates are removed.
166
- """
167
  out = {"labels": [], "tasks": []}
168
  allowed_map = canonicalize_map(allowed)
169
- # Filter labels
170
- filt_labels: List[str] = []
171
  for l in pred.get("labels", []) or []:
172
- if not isinstance(l, str):
173
- continue
174
- k = l.strip().lower()
175
  if k in allowed_map:
176
  filt_labels.append(allowed_map[k])
177
  filt_labels = normalize_labels(filt_labels)
178
- # Filter tasks
179
  filt_tasks = []
180
  for t in pred.get("tasks", []) or []:
181
  if not isinstance(t, dict):
182
  continue
183
- lbl = t.get("label", "")
184
- k = str(lbl).strip().lower()
185
  if k in allowed_map:
186
  new_t = dict(t)
187
  new_t["label"] = allowed_map[k]
188
  filt_tasks.append(new_t)
189
- # Merge labels from tasks
190
- from_tasks = [tt["label"] for tt in filt_tasks if isinstance(tt.get("label"), str)]
191
- merged = normalize_labels(list(set(filt_labels) | set(from_tasks)))
192
  out["labels"] = merged
193
  out["tasks"] = filt_tasks
194
  return out
195
 
196
 
197
- def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
198
- """
199
- Keep only the last `max_tokens` tokens of a string according to
200
- the provided tokenizer. This is useful for long transcripts to
201
- reduce inference time.
202
- """
203
- if max_tokens <= 0:
204
- return text
205
- tok = tokenizer(text, add_special_tokens=False)["input_ids"]
206
- if len(tok) <= max_tokens:
207
  return text
208
- return tokenizer.decode(tok[-max_tokens:], skip_special_tokens=True)
209
-
210
 
211
- # =============================================================================
212
- # Model Wrapper and Loader
213
- # =============================================================================
 
 
214
 
 
 
 
215
  class ModelWrapper:
216
- """
217
- Wraps a HuggingFace model and tokenizer, with optional 4‑bit
218
- quantisation. Instances are cached per model and quantisation
219
- setting.
220
- """
221
  def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
222
  self.repo_id = repo_id
223
  self.hf_token = hf_token
224
  self.load_in_4bit = load_in_4bit
225
- self.tokenizer: Optional[AutoTokenizer] = None
226
- self.model: Optional[AutoModelForCausalLM] = None
227
 
228
  def load(self):
229
- # 4‑bit quantisation config
230
  qcfg = None
231
  if self.load_in_4bit and DEVICE == "cuda":
232
  qcfg = BitsAndBytesConfig(
@@ -235,527 +152,433 @@ class ModelWrapper:
235
  bnb_4bit_compute_dtype=torch.float16,
236
  bnb_4bit_use_double_quant=True,
237
  )
238
- # Tokenizer
239
- self.tokenizer = AutoTokenizer.from_pretrained(
240
- self.repo_id,
241
- token=self.hf_token,
242
- cache_dir=str(SPACE_CACHE),
243
- trust_remote_code=True,
244
- use_fast=True,
245
  )
246
- if self.tokenizer.pad_token is None and self.tokenizer.eos_token:
247
- self.tokenizer.pad_token = self.tokenizer.eos_token
248
- # Model
249
- self.model = AutoModelForCausalLM.from_pretrained(
250
- self.repo_id,
251
- token=self.hf_token,
252
- cache_dir=str(SPACE_CACHE),
253
  trust_remote_code=True,
254
  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
255
  device_map="auto" if DEVICE == "cuda" else None,
256
- low_cpu_mem_usage=True,
257
- quantization_config=qcfg,
258
  attn_implementation="sdpa",
259
  )
 
 
260
 
261
  @torch.inference_mode()
262
  def generate(self, system_prompt: str, user_prompt: str) -> str:
263
- """
264
- Generate text from system and user prompts. Chat templates are
265
- used if defined on the tokenizer.
266
- """
267
- assert self.tokenizer is not None and self.model is not None
268
- # Chat template support
269
  if hasattr(self.tokenizer, "apply_chat_template"):
270
- messages = [
271
- {"role": "system", "content": system_prompt},
272
- {"role": "user", "content": user_prompt},
273
- ]
274
- input_ids = self.tokenizer.apply_chat_template(
275
- messages, add_generation_prompt=True, return_tensors="pt"
276
  ).to(self.model.device)
277
  else:
278
- text = f"<s>[SYSTEM]{system_prompt}[/SYSTEM][USER]{user_prompt}[/USER]"
279
- input_ids = self.tokenizer(text, return_tensors="pt").to(self.model.device)
 
280
  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
281
  out_ids = self.model.generate(
282
- **input_ids,
283
  generation_config=GEN_CONFIG,
284
  eos_token_id=self.tokenizer.eos_token_id,
285
  pad_token_id=self.tokenizer.pad_token_id,
286
  )
287
  return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
288
 
289
-
290
- # Model cache keyed by (repo_id, quantisation)
291
  _MODEL_CACHE: Dict[str, ModelWrapper] = {}
292
-
293
-
294
  def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
295
- """Retrieve or load a ModelWrapper from the cache."""
296
  key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
297
  if key not in _MODEL_CACHE:
298
- mw = ModelWrapper(repo_id, hf_token, load_in_4bit)
299
- mw.load()
300
- _MODEL_CACHE[key] = mw
301
  return _MODEL_CACHE[key]
302
 
303
-
304
- # =============================================================================
305
- # Evaluation Function
306
- # =============================================================================
307
-
308
  def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
309
- """
310
- Official weighted score for the challenge. False negatives
311
- incur double the penalty of false positives. Returns a score
312
- between 0.0 and 1.0, where 1.0 is perfect.
313
- """
314
- ALLOWED_LABELS = DEFAULT_ALLOWED_LABELS
315
  LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
316
  FN_PENALTY = 2.0
317
  FP_PENALTY = 1.0
318
  if len(y_true) != len(y_pred):
319
- raise ValueError(f"y_true and y_pred lengths differ: {len(y_true)} vs {len(y_pred)}")
 
320
  n_samples = len(y_true)
321
- n_labels = len(ALLOWED_LABELS)
322
  y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
323
  y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
324
- for i, labels in enumerate(y_true):
325
- for l in labels:
326
- if l not in LABEL_TO_IDX:
327
- raise ValueError(f"Invalid true label '{l}'")
328
- y_true_binary[i, LABEL_TO_IDX[l]] = 1
329
- for i, labels in enumerate(y_pred):
330
- for l in labels:
331
- if l not in LABEL_TO_IDX:
332
- raise ValueError(f"Invalid predicted label '{l}'")
333
- y_pred_binary[i, LABEL_TO_IDX[l]] = 1
334
- false_negatives = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
335
- false_positives = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
336
- weighted_errors = FN_PENALTY * false_negatives + FP_PENALTY * false_positives
337
- max_errors_per_sample = FN_PENALTY * np.sum(y_true_binary, axis=1) + FP_PENALTY * (
338
- n_labels - np.sum(y_true_binary, axis=1)
339
- )
340
- per_sample_scores = np.where(
341
- max_errors_per_sample > 0,
342
- 1.0 - (weighted_errors / max_errors_per_sample),
343
- 1.0,
344
- )
345
- final_score = float(np.mean(per_sample_scores))
346
- return max(0.0, min(1.0, final_score))
347
-
348
-
349
- # =============================================================================
350
- # Prediction Utilities
351
- # =============================================================================
352
-
353
- def predict_labels_for_text(
354
- model: ModelWrapper,
355
- transcript: str,
356
- allowed: List[str],
357
- max_tokens: int,
358
- ) -> List[str]:
359
- """
360
- Predict labels for a transcript string using the given model.
361
- The transcript is truncated to the last `max_tokens` tokens to
362
- reduce inference time. Only labels in `allowed` are returned.
363
- """
364
- # Truncate transcript
365
- truncated = truncate_tokens(model.tokenizer, transcript, max_tokens)
366
- allowed_list_str = "\n".join(f"- {lab}" for lab in allowed)
367
- user_prompt = USER_PROMPT_TEMPLATE.format(
368
- transcript=truncated,
369
- allowed_labels_list=allowed_list_str,
370
- )
371
- raw_out = model.generate(SYSTEM_PROMPT, user_prompt)
372
- parsed = robust_json_extract(raw_out)
373
- filtered = restrict_to_allowed(parsed, allowed)
374
- return filtered.get("labels", []) or []
375
 
376
-
377
- # =============================================================================
378
- # Single Transcript Handler
379
- # =============================================================================
380
 
381
  def run_single(
382
  transcript_text: str,
383
- transcript_file: Optional[gr.File],
 
384
  allowed_labels_text: str,
385
  model_repo: str,
386
  use_4bit: bool,
387
  max_input_tokens: int,
388
  hf_token: str,
389
  ) -> Tuple[str, str, str, str]:
390
- """
391
- Process a single transcript and return (summary, json_output,
392
- diagnostics, raw_model_output). The summary is human‑readable,
393
- json_output is the strict JSON string, diagnostics contains
394
- performance information, and raw_model_output is the unfiltered
395
- model response for debugging.
396
- """
397
  t0 = _now_ms()
398
- # Determine transcript text
399
- raw_text = read_file_to_text(transcript_file) if transcript_file else (transcript_text or "")
400
- raw_text = raw_text.strip()
 
401
  if not raw_text:
402
- return (
403
- "",
404
- "",
405
- "No transcript provided.",
406
- json.dumps({"labels": [], "tasks": []}, indent=2),
407
- )
408
- # Determine allowed labels
409
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
410
- allowed = normalize_labels(user_allowed or DEFAULT_ALLOWED_LABELS)
411
- # Load model
 
412
  try:
413
- model = get_model(model_repo, hf_token.strip() or None, use_4bit)
414
  except Exception as e:
415
- return (
416
- "",
417
- "",
418
- f"Model load failed: {e}",
419
- json.dumps({"labels": [], "tasks": []}, indent=2),
420
- )
421
- t1 = _now_ms()
422
- # Truncate transcript
423
- truncated = truncate_tokens(model.tokenizer, raw_text, max_input_tokens)
424
- allowed_list_str = "\n".join(f"- {lab}" for lab in allowed)
425
  user_prompt = USER_PROMPT_TEMPLATE.format(
426
- transcript=truncated,
427
  allowed_labels_list=allowed_list_str,
 
428
  )
 
429
  # Generate
 
430
  try:
431
- model_out = model.generate(SYSTEM_PROMPT, user_prompt)
432
  except Exception as e:
433
- return (
434
- "",
435
- "",
436
- f"Generation error: {e}",
437
- json.dumps({"labels": [], "tasks": []}, indent=2),
438
- )
439
  t2 = _now_ms()
440
- # Parse and filter
441
- parsed = robust_json_extract(model_out)
 
442
  filtered = restrict_to_allowed(parsed, allowed)
443
- # Compose summary
444
  labs = filtered.get("labels", [])
445
  tasks = filtered.get("tasks", [])
446
- summ_lines: List[str] = []
447
- if labs:
448
- summ_lines.append("Detected labels:\n - " + "\n - ".join(labs))
449
- else:
450
- summ_lines.append("Detected labels: (none)")
451
  if tasks:
452
- summ_lines.append("\nTasks:")
453
- for t in tasks:
454
- lab = t.get("label", "")
455
- expl = t.get("explanation", "")
456
- ev = t.get("evidence", "")
457
- trimmed = ev[:140] + ("…" if len(ev) > 140 else "")
458
- summ_lines.append(f"• [{lab}] {expl} | evidence: {trimmed}")
459
  else:
460
- summ_lines.append("\nTasks: (none)")
461
- summary = "\n".join(summ_lines)
462
- # Diagnostics
463
- diag = [
464
- f"Device: {DEVICE} (4‑bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
465
- f"Model: {model_repo}",
466
- f"Tokens (input ≤): {max_input_tokens}",
467
- f"Latency: load/prep {t1 - t0} ms, generate {t2 - t1} ms, total {t2 - t0} ms",
468
- f"Allowed labels (n={len(allowed)}): {', '.join(allowed)}",
469
- ]
470
- diag_str = "\n".join(diag)
471
- json_str = json.dumps(filtered, ensure_ascii=False, indent=2)
472
- raw_out = model_out.strip()
473
- return summary, json_str, diag_str, raw_out
474
-
475
-
476
- # =============================================================================
477
- # Batch Evaluation Handler
478
- # =============================================================================
479
 
480
  def run_batch(
481
- zip_file: Optional[gr.File],
482
- allowed_labels_text: str,
483
  model_repo: str,
484
  use_4bit: bool,
485
  max_input_tokens: int,
486
  hf_token: str,
487
- max_files: int,
488
  ) -> Tuple[str, str, str, pd.DataFrame, str]:
489
- """
490
- Run batch evaluation on a ZIP archive of transcripts and ground
491
- truths. Returns (score_str, recall_precision_f1_str, extra_info,
492
- dataframe, download_path).
493
- """
494
- if zip_file is None or not zip_file.name.lower().endswith(".zip"):
495
- return ("No ZIP file provided.", "", "", pd.DataFrame(), "")
496
- # Allowed labels
497
- user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
498
- allowed = normalize_labels(user_allowed or DEFAULT_ALLOWED_LABELS)
499
- # Load model once
500
  try:
501
- model = get_model(model_repo, hf_token.strip() or None, use_4bit)
502
  except Exception as e:
503
  return (f"Model load failed: {e}", "", "", pd.DataFrame(), "")
504
- # Extract ZIP to temp directory
505
- timestamp = int(time.time())
506
- extract_root = Path("/tmp") / f"batch_{timestamp}"
507
- extract_root.mkdir(parents=True, exist_ok=True)
508
- try:
509
- with zipfile.ZipFile(io.BytesIO(zip_file.read())) as zf:
510
- zf.extractall(extract_root)
511
- except Exception as e:
512
- return (f"Failed to extract ZIP: {e}", "", "", pd.DataFrame(), "")
513
- # Collect transcript and label paths
514
- transcript_paths: Dict[str, Path] = {}
515
- truth_paths: Dict[str, Path] = {}
516
- for path in extract_root.rglob("*"):
517
- if path.is_file():
518
- stem = path.stem
519
- ext = path.suffix.lower()
520
- if ext == ".txt":
521
- transcript_paths[stem] = path
522
- elif ext == ".json":
523
- truth_paths[stem] = path
524
- # Pair transcripts and truth files
525
- paired = [
526
- (stem, transcript_paths[stem], truth_paths.get(stem))
527
- for stem in sorted(transcript_paths.keys())
528
- ]
529
- if not paired:
530
- return ("No transcript files found in ZIP.", "", "", pd.DataFrame(), "")
531
- # Optionally limit number of files
532
- if max_files > 0:
533
- paired = paired[: max_files]
534
- # Lists for evaluation and per‑sample results
535
- y_true_list: List[List[str]] = []
536
- y_pred_list: List[List[str]] = []
537
- result_rows: List[Dict[str, Any]] = []
538
- total_tp = total_fp = total_fn = 0
539
- # Iterate through samples
540
- for stem, txt_path, truth_path in paired:
541
- # Read transcript
542
- try:
543
- with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
544
- transcript = f.read().strip()
545
- except Exception:
546
- transcript = ""
547
- # Read ground truth labels
548
- true_labels: List[str] = []
549
- if truth_path and truth_path.is_file():
550
- try:
551
- with open(truth_path, "r", encoding="utf-8", errors="ignore") as f:
552
- obj = json.load(f)
553
- if isinstance(obj, dict) and "labels" in obj:
554
- true_labels = [str(l).strip() for l in obj["labels"] if isinstance(l, str)]
555
- elif isinstance(obj, list):
556
- true_labels = [str(l).strip() for l in obj if isinstance(l, str)]
557
- except Exception:
558
- true_labels = []
559
- # Predict labels
560
- pred_labels: List[str] = []
561
- if transcript:
562
  try:
563
- pred_labels = predict_labels_for_text(model, transcript, allowed, max_input_tokens)
 
 
564
  except Exception:
565
- pred_labels = []
566
- # Compute per‑sample metrics
567
- true_set = set(true_labels)
568
- pred_set = set(pred_labels)
569
- tp = len(true_set & pred_set)
570
- fp = len(pred_set - true_set)
571
- fn = len(true_set - pred_set)
572
- total_tp += tp
573
- total_fp += fp
574
- total_fn += fn
575
- y_true_list.append(list(true_set))
576
- y_pred_list.append(list(pred_set))
577
- result_rows.append(
578
- {
579
- "file": stem,
580
- "true_labels": ", ".join(sorted(true_set)) if true_set else "",
581
- "pred_labels": ", ".join(sorted(pred_set)) if pred_set else "",
582
- "true_positives": tp,
583
- "false_positives": fp,
584
- "false_negatives": fn,
585
- }
586
- )
587
- # Compute metrics
588
- if y_true_list:
589
- try:
590
- weighted_score = evaluate_predictions(y_true_list, y_pred_list)
591
- except Exception:
592
- weighted_score = 0.0
593
- precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 1.0
594
- recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 1.0
595
- f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
596
- else:
597
- weighted_score = precision = recall = f1 = 0.0
598
- df = pd.DataFrame(result_rows)
599
- score_str = f"Weighted score: {weighted_score:.3f}"
600
- metrics_str = f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}"
601
- extra_str = f"Processed {len(paired)} transcripts | TP={total_tp} FP={total_fp} FN={total_fn}"
602
- # Write CSV
603
- csv_path = extract_root / "batch_results.csv"
604
- try:
605
- df.to_csv(csv_path, index=False)
606
- csv_path_str = str(csv_path)
607
- except Exception:
608
- csv_path_str = ""
609
- return (score_str, metrics_str, extra_str, df, csv_path_str)
610
-
611
-
612
- # =============================================================================
613
- # Interface
614
- # =============================================================================
615
-
616
- def build_ui() -> gr.Blocks:
617
- """Construct the Gradio interface."""
618
- with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
619
- gr.Markdown("# Talk2Task – Transcript Task Extraction and Evaluation")
620
- gr.Markdown(
621
- "This app extracts actionable tasks from client–advisor transcripts using "
622
- "a selectable language model. You can run it on a single transcript "
623
- "or evaluate a batch of transcripts against provided ground truth labels."
624
- )
625
- with gr.Tab("Single Transcript"):
626
- with gr.Row():
627
- with gr.Column(scale=3):
628
- transcript_file = gr.File(
629
- label="Upload transcript (.txt/.md/.json)",
630
- file_types=[".txt", ".md", ".json"],
631
- type="filepath",
632
- )
633
- transcript_text = gr.Textbox(
634
- label="Or paste transcript here",
635
- lines=12,
636
- placeholder="Paste conversation transcript…",
637
- )
638
- allowed_labels_text = gr.Textbox(
639
- label="Allowed Labels (one per line; leave blank for defaults)",
640
- lines=8,
641
- )
642
- with gr.Column(scale=2):
643
- model_repo = gr.Dropdown(
644
- label="Model Repository",
645
- choices=[
646
- "swiss-ai/Apertus-8B-Instruct-2509",
647
- "meta-llama/Meta-Llama-3-8B-Instruct",
648
- "mistralai/Mistral-7B-Instruct-v0.3",
649
- ],
650
- value="swiss-ai/Apertus-8B-Instruct-2509",
651
- )
652
- use_4bit = gr.Checkbox(
653
- label="Use 4-bit quantisation (GPU only)", value=True
654
- )
655
- max_input_tokens = gr.Slider(
656
- label="Max input tokens (truncate from end)",
657
- minimum=1024,
658
- maximum=8192,
659
- step=512,
660
- value=4096,
661
- )
662
- hf_token = gr.Textbox(
663
- label="HF_TOKEN (for gated/private models)",
664
- type="password",
665
- value=os.environ.get("HF_TOKEN", ""),
666
- )
667
- single_button = gr.Button("Run Extraction", variant="primary")
668
- with gr.Row():
669
- summary = gr.Textbox(label="Summary", lines=12)
670
- json_out = gr.Code(label="Strict JSON Output", language="json")
671
- with gr.Row():
672
- diag = gr.Textbox(label="Diagnostics", lines=6)
673
- raw_out = gr.Textbox(label="Raw Model Output", lines=6)
674
- # Hook up single button
675
- single_button.click(
676
- fn=run_single,
677
- inputs=[
678
- transcript_text,
679
- transcript_file,
680
- allowed_labels_text,
681
- model_repo,
682
- use_4bit,
683
- max_input_tokens,
684
- hf_token,
685
- ],
686
- outputs=[summary, json_out, diag, raw_out],
687
- )
688
- with gr.Tab("Batch Evaluation"):
689
- with gr.Row():
690
- with gr.Column(scale=3):
691
- zip_input = gr.File(
692
- label="ZIP of transcripts and labels", file_types=[".zip"], type="filepath"
693
- )
694
- batch_allowed_labels = gr.Textbox(
695
- label="Allowed Labels (one per line; leave blank for defaults)",
696
- lines=8,
697
- )
698
- max_files_slider = gr.Slider(
699
- label="Max files to process (0 = no limit)",
700
- minimum=0,
701
- maximum=1000,
702
- step=1,
703
- value=0,
704
- )
705
- with gr.Column(scale=2):
706
- batch_model_repo = gr.Dropdown(
707
- label="Model Repository",
708
- choices=[
709
- "swiss-ai/Apertus-8B-Instruct-2509",
710
- "meta-llama/Meta-Llama-3-8B-Instruct",
711
- "mistralai/Mistral-7B-Instruct-v0.3",
712
- ],
713
- value="swiss-ai/Apertus-8B-Instruct-2509",
714
- )
715
- batch_use_4bit = gr.Checkbox(
716
- label="Use 4-bit quantisation (GPU only)", value=True
717
- )
718
- batch_max_input_tokens = gr.Slider(
719
- label="Max input tokens (truncate from end)",
720
- minimum=1024,
721
- maximum=8192,
722
- step=512,
723
- value=4096,
724
- )
725
- batch_hf_token = gr.Textbox(
726
- label="HF_TOKEN (for gated/private models)",
727
- type="password",
728
- value=os.environ.get("HF_TOKEN", ""),
729
- )
730
- batch_button = gr.Button("Run Batch Evaluation", variant="primary")
731
- # Outputs
732
- batch_score = gr.Textbox(label="Score")
733
- batch_metrics = gr.Textbox(label="Recall / Precision / F1")
734
- batch_extra = gr.Textbox(label="Summary", lines=2)
735
- batch_df = gr.Dataframe(label="Per‑sample results", interactive=True, wrap=True)
736
- batch_download = gr.File(label="Download results (CSV)")
737
- # Hook up batch button
738
- def on_batch(zip_file, allowed_text, repo, use4, max_tok, token, max_f):
739
- score, metrics, extra, df, csv_path = run_batch(
740
- zip_file, allowed_text, repo, use4, max_tok, token, int(max_f)
741
  )
742
- return score, metrics, extra, df, csv_path
743
- batch_button.click(
744
- fn=on_batch,
745
- inputs=[
746
- zip_input,
747
- batch_allowed_labels,
748
- batch_model_repo,
749
- batch_use_4bit,
750
- batch_max_input_tokens,
751
- batch_hf_token,
752
- max_files_slider,
753
- ],
754
- outputs=[batch_score, batch_metrics, batch_extra, batch_df, batch_download],
755
- )
756
- return demo
758
 
759
  if __name__ == "__main__":
760
- demo = build_ui()
761
- demo.launch()
1
 
2
+ Allowed Labels (canonical; use only these):
 
 
 
 
 
3
  {allowed_labels_list}
4
 
5
+ Context cues (keywords/phrases that often indicate each label):
6
+ {keyword_context}
 
 
 
 
 
 
 
 
7
 
8
+ Instructions:
9
+ 1) Identify EVERY concrete task implied by the conversation.
10
+ 2) Choose ONE label from Allowed Labels for each task (or none if truly inapplicable).
11
+ 3) Return STRICT JSON only in the exact schema described by the system prompt.
12
+ """
13
 
14
+ # =========================
15
+ # Utilities
16
+ # =========================
17
  def _now_ms() -> int:
 
18
  return int(time.time() * 1000)
19
 
20
  def normalize_labels(labels: List[str]) -> List[str]:
 
21
  return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
22
 
 
23
  def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
 
24
  return {lab.lower(): lab for lab in allowed}
25
 
 
26
  def robust_json_extract(text: str) -> Dict[str, Any]:
 
 
 
 
 
27
  if not text:
28
  return {"labels": [], "tasks": []}
29
  start, end = text.find("{"), text.rfind("}")
30
+ candidate = text[start:end+1] if (start != -1 and end != -1 and end > start) else text
 
 
31
  try:
32
  return json.loads(candidate)
33
  except Exception:
34
+ candidate = re.sub(r",\s*}", "}", candidate)
35
+ candidate = re.sub(r",\s*]", "]", candidate)
36
+ try:
37
+ return json.loads(candidate)
38
+ except Exception:
39
+ return {"labels": [], "tasks": []}
40
 
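As a rough, hypothetical sketch of what that repair path recovers from a slightly malformed model reply (surrounding prose plus trailing commas):

messy = 'Sure! Here you go: {"labels": ["plan_contact",], "tasks": [],}'
print(robust_json_extract(messy))
# -> {'labels': ['plan_contact'], 'tasks': []}
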
41
  def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
 
 
 
 
 
42
  out = {"labels": [], "tasks": []}
43
  allowed_map = canonicalize_map(allowed)
44
+ # labels
45
+ filt_labels = []
46
  for l in pred.get("labels", []) or []:
47
+ k = str(l).strip().lower()
 
 
48
  if k in allowed_map:
49
  filt_labels.append(allowed_map[k])
50
  filt_labels = normalize_labels(filt_labels)
51
+ # tasks
52
  filt_tasks = []
53
  for t in pred.get("tasks", []) or []:
54
  if not isinstance(t, dict):
55
  continue
56
+ k = str(t.get("label", "")).strip().lower()
 
57
  if k in allowed_map:
58
  new_t = dict(t)
59
  new_t["label"] = allowed_map[k]
60
  filt_tasks.append(new_t)
61
+ merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
 
 
62
  out["labels"] = merged
63
  out["tasks"] = filt_tasks
64
  return out
65
 
66
+ # =========================
67
+ # Default pre-processing
68
+ # =========================
69
+ # These are conservative; they remove boilerplate that appears in many files
70
+ # and does not affect tasks. You can toggle this in the UI.
71
+ _DISCLAIMER_PATTERNS = [
72
+ r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
73
+ r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
74
+ r"(?is)^\s*this message \(including any attachments\).+?(?:\n{2,}|$)",
75
+ ]
76
+ _FOOTER_PATTERNS = [
77
+ r"(?is)\n+kind regards[^\n]*\n.*$", r"(?is)\n+best regards[^\n]*\n.*$",
78
+ r"(?is)\n+sent from my.*$", r"(?is)\n+ubs ag.*$",
79
+ ]
80
+ _TIMESTAMP_SPEAKER = [
81
+ r"\[\d{1,2}:\d{2}(:\d{2})?\]", # [00:01] or [00:01:02]
82
+ r"^\s*(advisor|client)\s*:\s*", # Advisor: / Client:
83
+ r"^\s*(speaker\s*\d+)\s*:\s*", # Speaker 1:
84
+ ]
85
 
86
+ def clean_transcript(text: str) -> str:
87
+ if not text:
 
 
 
 
 
 
 
 
88
  return text
89
+ s = text
90
+
91
+ # Remove common timestamps and speaker prefixes (line-wise)
92
+ lines = []
93
+ for ln in s.splitlines():
94
+ ln2 = ln
95
+ for pat in _TIMESTAMP_SPEAKER:
96
+ ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
97
+ lines.append(ln2)
98
+ s = "\n".join(lines)
99
+
100
+ # Remove top disclaimers
101
+ for pat in _DISCLAIMER_PATTERNS:
102
+ s = re.sub(pat, "", s).strip()
103
+
104
+ # Remove trailing footers/signatures
105
+ for pat in _FOOTER_PATTERNS:
106
+ s = re.sub(pat, "", s)
107
+
108
+ # Collapse repeated whitespace
109
+ s = re.sub(r"[ \t]+", " ", s)
110
+ s = re.sub(r"\n{3,}", "\n\n", s).strip()
111
+ return s
112
+
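For a rough feel of the default cleaning, a hypothetical before/after (made-up lines, not taken from the dataset):

sample = "[00:01] Advisor: Good morning.\n[00:02] Client: Please update my phone number."
print(clean_transcript(sample))
# -> "Good morning.\nPlease update my phone number."
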
113
+ def read_text_from_file(file: gr.File) -> str:
114
+ if not file or not file.name:
115
+ return ""
116
+ name = file.name.lower()
117
+ data = file.read()
118
+ if name.endswith(".json"):
119
+ try:
120
+ obj = json.loads(data.decode("utf-8", errors="ignore"))
121
+ if isinstance(obj, dict) and "transcript" in obj:
122
+ return str(obj["transcript"])
123
+ return json.dumps(obj, ensure_ascii=False)
124
+ except Exception:
125
+ return data.decode("utf-8", errors="ignore")
126
+ else:
127
+ return data.decode("utf-8", errors="ignore")
128
 
129
+ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
130
+ toks = tokenizer(text, add_special_tokens=False)["input_ids"]
131
+ if len(toks) <= max_tokens:
132
+ return text
133
+ return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
134
 
135
+ # =========================
136
+ # HF model wrapper
137
+ # =========================
138
  class ModelWrapper:
 
 
 
 
 
139
  def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
140
  self.repo_id = repo_id
141
  self.hf_token = hf_token
142
  self.load_in_4bit = load_in_4bit
143
+ self.tokenizer = None
144
+ self.model = None
145
 
146
  def load(self):
 
147
  qcfg = None
148
  if self.load_in_4bit and DEVICE == "cuda":
149
  qcfg = BitsAndBytesConfig(
 
152
  bnb_4bit_compute_dtype=torch.float16,
153
  bnb_4bit_use_double_quant=True,
154
  )
155
+ tok = AutoTokenizer.from_pretrained(
156
+ self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
157
+ trust_remote_code=True, use_fast=True,
 
 
 
 
158
  )
159
+ if tok.pad_token is None and tok.eos_token:
160
+ tok.pad_token = tok.eos_token
161
+ model = AutoModelForCausalLM.from_pretrained(
162
+ self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
 
 
 
163
  trust_remote_code=True,
164
  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
165
  device_map="auto" if DEVICE == "cuda" else None,
166
+ low_cpu_mem_usage=True, quantization_config=qcfg,
 
167
  attn_implementation="sdpa",
168
  )
169
+ self.tokenizer = tok
170
+ self.model = model
171
 
172
  @torch.inference_mode()
173
  def generate(self, system_prompt: str, user_prompt: str) -> str:
 
 
 
 
 
 
174
  if hasattr(self.tokenizer, "apply_chat_template"):
175
+ msgs = [{"role": "system", "content": system_prompt},
176
+ {"role": "user", "content": user_prompt}]
177
+ inputs = self.tokenizer.apply_chat_template(
178
+ msgs, add_generation_prompt=True, return_tensors="pt"
 
 
179
  ).to(self.model.device)
180
  else:
181
+ text = f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n"
182
+ inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
183
+
184
  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
185
  out_ids = self.model.generate(
186
+ **inputs,
187
  generation_config=GEN_CONFIG,
188
  eos_token_id=self.tokenizer.eos_token_id,
189
  pad_token_id=self.tokenizer.pad_token_id,
190
  )
191
  return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
192
 
 
 
193
  _MODEL_CACHE: Dict[str, ModelWrapper] = {}
 
 
194
  def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
 
195
  key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
196
  if key not in _MODEL_CACHE:
197
+ m = ModelWrapper(repo_id, hf_token, load_in_4bit)
198
+ m.load()
199
+ _MODEL_CACHE[key] = m
200
  return _MODEL_CACHE[key]
201
 
202
+ # =========================
203
+ # Official evaluation (from README)
204
+ # =========================
 
 
205
  def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
206
+ ALLOWED_LABELS = OFFICIAL_LABELS
 
 
 
 
 
207
  LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
208
  FN_PENALTY = 2.0
209
  FP_PENALTY = 1.0
210
+
211
+ def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
212
+ if not isinstance(sample_labels, list):
213
+ raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
214
+ # dedupe
215
+ seen, uniq = set(), []
216
+ for label in sample_labels:
217
+ if not isinstance(label, str):
218
+ raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
219
+ if label in seen:
220
+ raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
221
+ seen.add(label); uniq.append(label)
222
+ # validity
223
+ valid = []
224
+ for label in uniq:
225
+ if label not in ALLOWED_LABELS:
226
+ raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
227
+ valid.append(label)
228
+ return valid
229
+
230
  if len(y_true) != len(y_pred):
231
+ raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
232
+
233
  n_samples = len(y_true)
234
+ n_labels = len(OFFICIAL_LABELS)
235
  y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
236
  y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
237
 
238
+ for i, sample_labels in enumerate(y_true):
239
+ for label in _process_sample_labels(sample_labels, f"y_true[{i}]"):
240
+ y_true_binary[i, LABEL_TO_IDX[label]] = 1
241
+
242
+ for i, sample_labels in enumerate(y_pred):
243
+ for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
244
+ y_pred_binary[i, LABEL_TO_IDX[label]] = 1
245
+
246
+ fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
247
+ fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
248
+ weighted = 2.0 * fn + 1.0 * fp
249
+ max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
250
+ per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
251
+ return float(max(0.0, min(1.0, np.mean(per_sample))))
252
+
253
+ # =========================
254
+ # Inference helpers
255
+ # =========================
256
+ def build_keyword_context(allowed: List[str]) -> str:
257
+ parts = []
258
+ for lab in allowed:
259
+ kws = LABEL_KEYWORDS.get(lab, [])
260
+ if kws:
261
+ parts.append(f"- {lab}: " + ", ".join(kws))
262
+ else:
263
+ parts.append(f"- {lab}: (no default cues)")
264
+ return "\n".join(parts)
265
 
266
  def run_single(
267
  transcript_text: str,
268
+ transcript_file: gr.File,
269
+ use_cleaning: bool,
270
  allowed_labels_text: str,
271
  model_repo: str,
272
  use_4bit: bool,
273
  max_input_tokens: int,
274
  hf_token: str,
275
  ) -> Tuple[str, str, str, str]:
276
+
 
 
 
 
 
 
277
  t0 = _now_ms()
278
+
279
+ # Get transcript
280
+ raw_text = read_text_from_file(transcript_file) if transcript_file else (transcript_text or "")
281
+ raw_text = (raw_text or "").strip()
282
  if not raw_text:
283
+ return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, indent=2)
284
+
285
+ # Cleaning
286
+ text = clean_transcript(raw_text) if use_cleaning else raw_text
287
+
288
+ # Allowed labels
 
289
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
290
+ allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
291
+
292
+ # Model
293
  try:
294
+ model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
295
  except Exception as e:
296
+ return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
297
+
298
+ # Truncate
299
+ trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
300
+
301
+ # Build prompt
302
+ allowed_list_str = "\n".join(f"- {l}" for l in allowed)
303
+ keyword_ctx = build_keyword_context(allowed)
 
 
304
  user_prompt = USER_PROMPT_TEMPLATE.format(
305
+ transcript=trunc,
306
  allowed_labels_list=allowed_list_str,
307
+ keyword_context=keyword_ctx,
308
  )
309
+
310
  # Generate
311
+ t1 = _now_ms()
312
  try:
313
+ out = model.generate(SYSTEM_PROMPT, user_prompt)
314
  except Exception as e:
315
+ return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
 
 
 
 
 
316
  t2 = _now_ms()
317
+
318
+ # Parse + filter
319
+ parsed = robust_json_extract(out)
320
  filtered = restrict_to_allowed(parsed, allowed)
321
+
322
+ # Diagnostics
323
+ diag = "\n".join([
324
+ f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
325
+ f"Model: {model_repo}",
326
+ f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
327
+ f"Tokens (input, approx): ≤ {max_input_tokens}",
328
+ f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
329
+ f"Allowed labels: {', '.join(allowed)}",
330
+ ])
331
+
332
+ # Summary
333
  labs = filtered.get("labels", [])
334
  tasks = filtered.get("tasks", [])
335
+ summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
 
 
 
 
336
  if tasks:
337
+ summary += "\n\nTasks:\n" + "\n".join(
338
+ f"• [{t['label']}] {t.get('explanation','')} | ev: {t.get('evidence','')[:140]}{'…' if len(t.get('evidence',''))>140 else ''}"
339
+ for t in tasks
340
+ )
 
 
 
341
  else:
342
+ summary += "\n\nTasks: (none)"
343
+
344
+ return summary, json.dumps(filtered, indent=2, ensure_ascii=False), diag, out.strip()
345
+
346
+ # =========================
347
+ # Batch mode (ZIP with transcripts + truths)
348
+ # =========================
349
+ def read_zip(fileobj: io.BytesIO, exdir: Path) -> List[Path]:
350
+ exdir.mkdir(parents=True, exist_ok=True)
351
+ with zipfile.ZipFile(fileobj) as zf:
352
+ zf.extractall(exdir)
353
+ out = []
354
+ for p in exdir.rglob("*"):
355
+ if p.is_file():
356
+ out.append(p)
357
+ return out
 
 
 
358
 
359
  def run_batch(
360
+ zip_file: gr.File,
361
+ use_cleaning: bool,
362
  model_repo: str,
363
  use_4bit: bool,
364
  max_input_tokens: int,
365
  hf_token: str,
366
+ limit_files: int,
367
  ) -> Tuple[str, str, str, pd.DataFrame, str]:
368
+
369
+ if not zip_file:
370
+ return ("No ZIP provided.", "", "", pd.DataFrame(), "")
371
+
372
+ work = Path("/tmp/batch")
373
+ if work.exists():
374
+ for p in work.rglob("*"):
375
+ try: p.unlink()
376
+ except Exception: pass
377
+ try: work.rmdir()
378
+ except Exception: pass
379
+ work.mkdir(parents=True, exist_ok=True)
380
+
381
+ # Unzip
382
+ data = zip_file.read()
383
+ files = read_zip(io.BytesIO(data), work)
384
+
385
+ # Gather pairs by stem
386
+ txts: Dict[str, Path] = {}
387
+ gts: Dict[str, Path] = {}
388
+ for p in files:
389
+ if p.suffix.lower() == ".txt":
390
+ txts[p.stem] = p
391
+ elif p.suffix.lower() == ".json":
392
+ gts[p.stem] = p
393
+
394
+ stems = sorted(txts.keys())
395
+ if limit_files > 0:
396
+ stems = stems[:limit_files]
397
+ if not stems:
398
+ return ("No .txt transcripts found in ZIP.", "", "", pd.DataFrame(), "")
399
+
400
+ # Model
401
  try:
402
+ model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
403
  except Exception as e:
404
  return (f"Model load failed: {e}", "", "", pd.DataFrame(), "")
405
+
406
+ allowed = OFFICIAL_LABELS[:] # fixed for scoring
407
+ allowed_list_str = "\n".join(f"- {l}" for l in allowed)
408
+ keyword_ctx = build_keyword_context(allowed)
409
+
410
+ y_true, y_pred = [], []
411
+ rows = []
412
+ t_start = _now_ms()
413
+
414
+ for stem in stems:
415
+ raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
416
+ text = clean_transcript(raw) if use_cleaning else raw
417
+ trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
418
+
419
+ user_prompt = USER_PROMPT_TEMPLATE.format(
420
+ transcript=trunc,
421
+ allowed_labels_list=allowed_list_str,
422
+ keyword_context=keyword_ctx,
423
+ )
424
+
425
+ t0 = _now_ms()
426
+ out = model.generate(SYSTEM_PROMPT, user_prompt)
427
+ t1 = _now_ms()
428
+
429
+ parsed = robust_json_extract(out)
430
+ filtered = restrict_to_allowed(parsed, allowed)
431
+ pred_labels = filtered.get("labels", [])
432
+ y_pred.append(pred_labels)
433
+
434
+ # Ground truth (optional)
435
+ gt_labels = []
436
+ if stem in gts:
437
  try:
438
+ gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
439
+ if isinstance(gt_obj, dict) and "labels" in gt_obj and isinstance(gt_obj["labels"], list):
440
+ gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
441
  except Exception:
442
+ pass
443
+ y_true.append(gt_labels)
444
+
445
+ # FP/FN counts for table
446
+ gt_set = set(gt_labels)
447
+ pr_set = set(pred_labels)
448
+ tp = sorted(gt_set & pr_set)
449
+ fp = sorted(pr_set - gt_set)
450
+ fn = sorted(gt_set - pr_set)
451
+
452
+ rows.append({
453
+ "file": stem,
454
+ "true_labels": ", ".join(gt_labels),
455
+ "pred_labels": ", ".join(pred_labels),
456
+ "TP": len(tp), "FP": len(fp), "FN": len(fn),
457
+ "gen_ms": t1 - t0
458
+ })
459
+
460
+ # Metrics
461
+ # If there is no ground truth in the ZIP, we still compute a table and skip score.
462
+ have_truth = any(len(v) > 0 for v in y_true)
463
+ score = evaluate_predictions(y_true, y_pred) if have_truth else None
464
+
465
+ df = pd.DataFrame(rows).sort_values(["FN", "FP", "file"])
466
+ diag = [
467
+ f"Processed files: {len(stems)}",
468
+ f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
469
+ f"Model: {model_repo}",
470
+ f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
471
+ f"Tokens (input, approx): {max_input_tokens}",
472
+ f"Batch time: {_now_ms()-t_start} ms",
473
+ ]
474
+ if have_truth and score is not None:
475
+ # Simple derived metrics
476
+ total_tp = int(df["TP"].sum())
477
+ total_fp = int(df["FP"].sum())
478
+ total_fn = int(df["FN"].sum())
479
+ recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 1.0
480
+ precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 1.0
481
+ f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
482
+ diag += [
483
+ f"Official weighted score (0–1): {score:.3f}",
484
+ f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}",
485
+ f"Total TP={total_tp} FP={total_fp} FN={total_fn}",
486
+ ]
487
+ diag_str = "\n".join(diag)
488
+
489
+ # Build CSV text and write it to a temp file for download
490
+ csv_buf = io.StringIO()
491
+ df.to_csv(csv_buf, index=False)
492
+ csv_data = csv_buf.getvalue()
493
+
494
+ csv_path = Path("/tmp/batch_results.csv")
+ csv_path.write_text(csv_data, encoding="utf-8")
+ return ("Batch done.", diag_str, str(csv_path), df, csv_data)
495
+
496
+ # =========================
497
+ # UI
498
+ # =========================
499
+ MODEL_CHOICES = [
500
+ "swiss-ai/Apertus-8B-Instruct-2509",
501
+ "meta-llama/Meta-Llama-3-8B-Instruct",
502
+ "mistralai/Mistral-7B-Instruct-v0.3",
503
+ ]
504
+
505
+ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
506
+ gr.Markdown("# Talk2Task — Task Extraction (UBS Challenge)")
507
+ gr.Markdown(
508
+ "This tool extracts challenge labels from transcripts. "
509
+ "Use **Single** for quick tests; use **Batch** to score a ZIP with transcripts + truths. "
510
+ "_Note: False negatives are penalised twice as much as false positives in the official metric; "
511
+ "we bias for recall._"
512
+ )
513
+
514
+ with gr.Tab("Single transcript"):
515
+ with gr.Row():
516
+ with gr.Column(scale=3):
517
+ file = gr.File(
518
+ label="Drag & drop transcript (.txt / .md / .json)",
519
+ file_types=[".txt", ".md", ".json"],
520
+ type="filepath",
521
  )
522
+ text = gr.Textbox(label="Or paste transcript", lines=14)
523
+ use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, footers)", value=True)
524
+ labels_text = gr.Textbox(
525
+ label="Allowed Labels (one per line; leave empty to use official list)",
526
+ value="",
527
+ lines=8,
528
+ )
529
+ with gr.Column(scale=2):
530
+ repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
531
+ use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
532
+ max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
533
+ hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
534
+ run_btn = gr.Button("Run Extraction", variant="primary")
535
+
536
+ with gr.Row():
537
+ summary = gr.Textbox(label="Summary", lines=12)
538
+ json_out = gr.Code(label="Strict JSON Output", language="json")
539
+ with gr.Row():
540
+ diag = gr.Textbox(label="Diagnostics", lines=8)
541
+ raw = gr.Textbox(label="Raw Model Output", lines=8)
542
+
543
+ run_btn.click(
544
+ fn=run_single,
545
+ inputs=[text, file, use_cleaning, labels_text, repo, use_4bit, max_tokens, hf_token],
546
+ outputs=[summary, json_out, diag, raw],
547
+ )
548
 
549
+ with gr.Tab("Batch evaluation"):
550
+ with gr.Row():
551
+ with gr.Column(scale=3):
552
+ zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
553
+ use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
554
+ with gr.Column(scale=2):
555
+ repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
556
+ use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
557
+ max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
558
+ hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
559
+ limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
560
+ run_batch_btn = gr.Button("Run Batch", variant="primary")
561
+
562
+ with gr.Row():
563
+ status = gr.Textbox(label="Status", lines=1)
564
+ diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=10)
565
+
566
+ with gr.Row():
567
+ df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, times)", interactive=False)
568
+ csv_out = gr.File(label="Download CSV (click to save)", interactive=False)
569
+
570
+ def _save_csv(csv_text: str) -> str:
571
+ if not csv_text:
572
+ return ""
573
+ out_path = Path("/tmp/batch_results.csv")
574
+ out_path.write_text(csv_text, encoding="utf-8")
575
+ return str(out_path)
576
+
577
+ run_batch_btn.click(
578
+ fn=run_batch,
579
+ inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
580
+ outputs=[status, diag_b, csv_out, df_out, gr.Textbox(visible=False)],
581
+ )
582
 
583
  if __name__ == "__main__":
584
+ demo.launch()