RishiRP committed on
Commit 55c81fd · verified · 1 Parent(s): 24080d0

Update app.py

Files changed (1)
  1. app.py +860 -341
app.py CHANGED
@@ -1,24 +1,43 @@
- import os as _os
- # Guard: prevent invalid OMP_NUM_THREADS setting
- if not _os.environ.get("OMP_NUM_THREADS"):
-     _os.environ["OMP_NUM_THREADS"] = "1"
-
  import os
  import json
- from typing import Optional, Tuple, Dict, Any, List

  import gradio as gr
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from langdetect import detect, DetectorFactory
-
- # Make langdetect deterministic
- DetectorFactory.seed = 7

  # =========================
- # Challenge: allowed labels (from UBS repo)
  # =========================
- ALLOWED_LABELS = [
      "plan_contact",
      "schedule_meeting",
      "update_contact_info_non_postal",
@@ -28,397 +47,897 @@ ALLOWED_LABELS = [
      "update_kyc_purpose_of_businessrelation",
      "update_kyc_total_assets",
  ]

  # =========================
- # Models / Defaults
  # =========================
- DEFAULT_MODEL_ID = os.environ.get("MODEL_ID", "Apertus/Apertus-8B")
- SUPPORTED_MODELS = [
-     "swiss-ai/Apertus-8B-Instruct-2509",
-     "meta-llama/Meta-Llama-3.1-8B-Instruct",
-     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
- ]
-
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- def _has_bnb_and_cuda() -> bool:
-     if DEVICE != "cuda":
-         return False
-     try:
-         import bitsandbytes as _bnb  # noqa: F401
-         return True
-     except Exception:
-         return False

- USE_BNB = _has_bnb_and_cuda()

  # =========================
- # Model cache
  # =========================
- _tokenizer: Optional[AutoTokenizer] = None
- _model: Optional[AutoModelForCausalLM] = None
- _current_model_id: Optional[str] = None
-
- def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
-     global _tokenizer, _model, _current_model_id
-
-     if _tokenizer is not None and _model is not None and _current_model_id == model_id:
-         return _tokenizer, _model
-
-     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
-
-     if USE_BNB:
-         from transformers import BitsAndBytesConfig
-         quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             quantization_config=quant,
-             device_map="auto",
-             trust_remote_code=True,
-         )
-     else:
-         dtype = torch.float16 if DEVICE == "cuda" else torch.float32
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             torch_dtype=dtype,
-             low_cpu_mem_usage=True,
-             trust_remote_code=True,
-         ).to(DEVICE)

-     _tokenizer, _model, _current_model_id = tokenizer, model, model_id
-     return tokenizer, model

  # =========================
- # Helpers
  # =========================
- def read_file(file_obj: Optional[gr.File]) -> Optional[str]:
-     if not file_obj:
          return None
      try:
-         with open(file_obj.name, "r", encoding="utf-8", errors="ignore") as f:
-             return f.read()
      except Exception:
          return None

- def normalize_txt_input(paste_txt: str, upload_file: Optional[gr.File]) -> str:
-     return paste_txt.strip() if (paste_txt and paste_txt.strip()) else (read_file(upload_file) or "")

- def normalize_json_input(paste_json: str, upload_file: Optional[gr.File]) -> str:
-     if paste_json and paste_json.strip():
-         return paste_json
-     return read_file(upload_file) or ""

- def safe_lang_detect(text: str) -> str:
-     try:
-         if not text or not text.strip():
-             return "unknown"
-         return detect(text)
-     except Exception:
-         return "unknown"

- def count_tokens(tokenizer: AutoTokenizer, text: str) -> int:
-     try:
-         return len(tokenizer(text, return_tensors=None).get("input_ids", []))
-     except Exception:
-         return max(1, len(text.split()))

  # =========================
- # Evaluation function (from repo)
  # =========================
  def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
-     import numpy as np
-
      LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
-     FN_PENALTY = 2.0
-     FP_PENALTY = 1.0

      if len(y_true) != len(y_pred):
          raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")

      n_samples = len(y_true)
-     n_labels = len(ALLOWED_LABELS)
-
      y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
      y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)

-     def _process(sample_labels: List[str], sample_name: str) -> List[str]:
-         if not isinstance(sample_labels, list):
-             raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
-         seen = set()
-         valid = []
-         for lbl in sample_labels:
-             if not isinstance(lbl, str):
-                 raise ValueError(f"{sample_name} contains non-string label: {lbl}")
-             if lbl in seen:
-                 raise ValueError(f"{sample_name} contains duplicate label: '{lbl}'")
-             seen.add(lbl)
-             if lbl not in ALLOWED_LABELS:
-                 raise ValueError(f"{sample_name} contains invalid label: '{lbl}'. Allowed: {ALLOWED_LABELS}")
-             valid.append(lbl)
-         return valid
-
-     for i, lbls in enumerate(y_true):
-         for lbl in _process(lbls, f"y_true[{i}]"):
-             y_true_binary[i, LABEL_TO_IDX[lbl]] = 1
-
-     for i, lbls in enumerate(y_pred):
-         for lbl in _process(lbls, f"y_pred[{i}]"):
-             y_pred_binary[i, LABEL_TO_IDX[lbl]] = 1
-
-     false_negatives = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
-     false_positives = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
-     weighted_errors = FN_PENALTY * false_negatives + FP_PENALTY * false_positives
-     max_errors_per_sample = FN_PENALTY * np.sum(y_true_binary, axis=1) + FP_PENALTY * (n_labels - np.sum(y_true_binary, axis=1))
-     per_sample_scores = np.where(max_errors_per_sample > 0, 1.0 - (weighted_errors / max_errors_per_sample), 1.0)
-     return float(np.mean(per_sample_scores))

  # =========================
- # Core Extraction
  # =========================
- def run_extraction(
-     model_choice: str,
-     params_checked: list,
-     instructions_text: str,
-     context_text: str,
-     txt_paste: str,
-     txt_upload: Optional[gr.File],
-     json_paste: str,
-     json_upload: Optional[gr.File],
-     max_new_tokens: int,
-     temperature: float,
-     top_p: float,
-     usd_per_1k_tokens: float,
- ) -> Tuple[str, str, str, str, str]:
-     diagnostics_lines = []
-
-     input_txt = normalize_txt_input(txt_paste, txt_upload)
-     input_json_raw = normalize_json_input(json_paste, json_upload)
-
-     lang = safe_lang_detect(input_txt)
-     parsed_json: Dict[str, Any] = {}
-     json_parse_ok = False
-     if input_json_raw:
-         try:
-             parsed_json = json.loads(input_json_raw)
-             json_parse_ok = True
-         except Exception as e:
-             diagnostics_lines.append(f"JSON parse error: {e}")

      try:
-         tokenizer, model = load_model(model_choice)
      except Exception as e:
-         diag = "\n".join([
-             f"Model: {model_choice}",
-             f"Params: {params_checked}",
-             f"Language detected: {lang}",
-             f"TXT length: {len(input_txt)}",
-             f"JSON parsed: {json_parse_ok}",
-             f"Model load failed: {e}"
-         ])
-         return "", "", "", "", diag
-
-     in_tokens = count_tokens(tokenizer, input_txt) + count_tokens(tokenizer, json.dumps(parsed_json) if parsed_json else "")
-
-     user_prompt = (
-         "You analyze client-conversation transcripts.\n"
-         "Transcripts may be multilingual. Detect the language automatically. "
-         "Extract tasks and entities correctly regardless of language. "
-         "Always write the short summary in English.\n"
-         f"Instructions: {instructions_text}\n"
-         f"Context: {context_text}\n"
-         "----\n"
-         f"TEXT:\n{input_txt[:4000]}\n"
-         "----\n"
-         f"JSON:\n{json.dumps(parsed_json)[:2000]}\n"
-         "Output:\n"
-         "- Tasks list (use allowed labels where possible)\n"
-         "- Entities list\n"
-         "- Cleaned text\n"
-         "- Short summary (English)\n"
      )
-     prompt_tokens = count_tokens(tokenizer, user_prompt)

      try:
-         inputs = tokenizer(user_prompt, return_tensors="pt").to(DEVICE)
-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=max_new_tokens,
-                 do_sample=True,
-                 temperature=temperature,
-                 top_p=top_p,
-                 pad_token_id=tokenizer.eos_token_id,
-             )
-         full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
      except Exception as e:
-         diag = "\n".join([
-             f"Model: {model_choice}",
-             f"Params: {params_checked}",
-             f"Language detected: {lang}",
-             f"TXT length: {len(input_txt)}",
-             f"JSON parsed: {json_parse_ok}",
-             f"Inference failed: {e}"
-         ])
-         return "", "", "", "", diag
-
-     tasks_out = "• plan_contact\n• schedule_meeting"
-     entities_out = "• Client: John Doe\n• Product: Mortgage"
-     cleaned_out = "Cleaned transcript text here…"
-     summary_out = "A short English summary of the conversation."
-
-     out_tokens = count_tokens(tokenizer, full_text)
-     total_tokens = in_tokens + prompt_tokens + out_tokens
-     est_cost = (total_tokens / 1000.0) * max(0.0, float(usd_per_1k_tokens))
-
-     diagnostics_lines.extend([
-         f"Model: {model_choice}",
-         f"Params: {params_checked}",
-         f"Language detected: {lang}",
-         f"TXT length: {len(input_txt)}",
-         f"JSON parsed: {json_parse_ok}",
-         f"Input tokens: {in_tokens}",
-         f"Prompt tokens: {prompt_tokens}",
-         f"Output tokens: {out_tokens}",
-         f"Total tokens: {total_tokens}",
-         f"Est. cost @ ${usd_per_1k_tokens:.4f}/1k toks: ${est_cost:.6f}",
-         "Generation completed successfully.",
      ])
-     diagnostics = "\n".join(diagnostics_lines)

-     return tasks_out, entities_out, cleaned_out, summary_out, diagnostics

  # =========================
- # Evaluation handler
  # =========================
- def evaluate_ui(y_true_text: str, y_true_file: Optional[gr.File], y_pred_text: str, y_pred_file: Optional[gr.File]) -> str:
-     def _load_json(text: str, file_obj: Optional[gr.File]) -> Any:
-         if text and text.strip():
-             return json.loads(text)
-         ftxt = read_file(file_obj)
-         if ftxt:
-             return json.loads(ftxt)
-         raise ValueError("Missing JSON input")

      try:
-         y_true = _load_json(y_true_text, y_true_file)
-         y_pred = _load_json(y_pred_text, y_pred_file)
-         score = evaluate_predictions(y_true, y_pred)
-         return f"Evaluation score: {score:.4f} (higher is better; weighted FN>FP)"
      except Exception as e:
-         return f"Evaluation error: {e}"

  # =========================
- # UI Styling
  # =========================
- THEME_CSS = """
- :root {
-   --body-background-fill: #ffffff !important;
-   --body-text-color: #111111 !important;
-   --link-text-color: #0b63ce !important;
- }
- .gradio-container, .prose, .prose * { color: #111111 !important; }
- label { color: #0b63ce !important; }
- button#run-btn {
-   background: #e11900 !important;
-   color: #fff !important;
-   border: 1px solid #b50f00 !important;
- }
  """

- # =========================
- # UI Layout
- # =========================
- def build_interface() -> gr.Blocks:
-     with gr.Blocks(title="Talk2Task Demo", css=THEME_CSS) as demo:
-         with gr.Group():
-             gr.Markdown("### Model & Parameters")
-             with gr.Row():
-                 model_choice = gr.Dropdown(
-                     label="Model",
-                     choices=SUPPORTED_MODELS,
-                     value=DEFAULT_MODEL_ID,
-                     scale=3,
-                 )
-                 params_checked = gr.CheckboxGroup(
-                     label="Options",
-                     choices=[
-                         "Default cleaning",
-                         "Remove PII",
-                         "Detect language",
-                         "Use 4-bit if available",
-                     ],
-                     value=["Default cleaning", "Detect language"],
-                     scale=2,
-                 )
-             with gr.Row():
-                 max_new_tokens = gr.Slider(64, 1024, value=200, step=16, label="Max new tokens")
-                 temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
-                 top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
-                 usd_per_1k_tokens = gr.Number(value=0.002, label="Est. $ per 1k tokens (edit)")
-
-         gr.Markdown("### Input")
-         with gr.Row(equal_height=True):
-             with gr.Group():
-                 gr.Markdown("**TXT Input**")
-                 with gr.Tabs():
-                     with gr.TabItem("Paste"):
-                         txt_paste = gr.TextArea(label="Paste TXT", lines=12, placeholder="Paste transcript (any language)…")
-                     with gr.TabItem("Upload"):
-                         txt_upload = gr.File(label="Upload TXT", file_types=[".txt"])
-             with gr.Group():
-                 gr.Markdown("**JSON Input**")
-                 with gr.Tabs():
-                     with gr.TabItem("Paste"):
-                         json_paste = gr.Code(label="Paste JSON", language="json", value="{\n  \"example\": true\n}", lines=12)
-                     with gr.TabItem("Upload"):
-                         json_upload = gr.File(label="Upload JSON", file_types=[".json"])
-
-         run_btn = gr.Button("Run Extraction", elem_id="run-btn")

          with gr.Row():
-             with gr.Accordion("Instructions (editable)", open=False):
-                 instructions_text = gr.TextArea(
-                     value="Extract key tasks, entities, cleaned text, and summary. Be robust to noise; avoid hallucinations.",
-                     lines=5
                  )
-             with gr.Accordion("Context (editable)", open=False):
-                 context_text = gr.TextArea(
-                     value="Banking client-advisor context. Transcripts may be multilingual; always summarize in English.",
-                     lines=5
                  )

-         gr.Markdown("### Results")
-         with gr.Row(equal_height=True):
-             tasks_out = gr.TextArea(label="Tasks", lines=8)
-             entities_out = gr.TextArea(label="Entities", lines=8)
-         with gr.Row(equal_height=True):
-             cleaned_out = gr.TextArea(label="Cleaned Text", lines=8)
-             summary_out = gr.TextArea(label="Summary (English)", lines=8)
-
-         diagnostics = gr.TextArea(label="Diagnostics", lines=12)
-
-         with gr.Accordion("Evaluation", open=False):
-             with gr.Row():
-                 y_true_text = gr.Code(label="y_true (JSON)", language="json", lines=10)
-                 y_pred_text = gr.Code(label="y_pred (JSON)", language="json", lines=10)
-             with gr.Row():
-                 y_true_file = gr.File(label="Upload y_true.json", file_types=[".json"])
-                 y_pred_file = gr.File(label="Upload y_pred.json", file_types=[".json"])
-             eval_btn = gr.Button("Compute Official Score")
-             eval_result = gr.Textbox(label="Evaluation Result")
-             eval_btn.click(evaluate_ui, inputs=[y_true_text, y_true_file, y_pred_text, y_pred_file], outputs=eval_result)
-
-         run_inputs = [
-             model_choice, params_checked, instructions_text, context_text,
-             txt_paste, txt_upload, json_paste, json_upload,
-             max_new_tokens, temperature, top_p, usd_per_1k_tokens
-         ]
-         run_outputs = [tasks_out, entities_out, cleaned_out, summary_out, diagnostics]
-         run_btn.click(fn=run_extraction, inputs=run_inputs, outputs=run_outputs)

-     return demo

- demo = build_interface()

  if __name__ == "__main__":
      demo.launch()

+ # app.py
  import os
+ import re
+ import io
  import json
+ import time
+ import zipfile
+ from pathlib import Path
+ from typing import List, Dict, Any, Tuple, Optional

+ import numpy as np
+ import pandas as pd
  import gradio as gr

+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     GenerationConfig,
+ )

  # =========================
+ # Global config
  # =========================
+ SPACE_CACHE = Path.home() / ".cache" / "huggingface"
+ SPACE_CACHE.mkdir(parents=True, exist_ok=True)
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+ # Deterministic, compact outputs
+ GEN_CONFIG = GenerationConfig(
+     temperature=0.0,
+     top_p=1.0,
+     do_sample=False,
+     max_new_tokens=128,  # raise if your JSON truncates
+ )
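+ # NOTE: with do_sample=False decoding is greedy, so the temperature/top_p values
+ # above are effectively inert (recent transformers versions may warn about them).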
+
+ # Canonical labels (UBS)
+ OFFICIAL_LABELS = [
      "plan_contact",
      "schedule_meeting",
      "update_contact_info_non_postal",
      "update_kyc_purpose_of_businessrelation",
      "update_kyc_total_assets",
  ]
+ OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)

  # =========================
+ # Editable defaults (shown in UI)
  # =========================
+ DEFAULT_SYSTEM_INSTRUCTIONS = (
+     "You extract ACTIONABLE TASKS from client–advisor transcripts. "
+     "The transcript may be in German, French, Italian, or English. "
+     "Prioritize RECALL: if a label plausibly applies, include it. "
+     "Use ONLY the canonical labels provided. "
+     "Return STRICT JSON only with keys 'labels' and 'tasks'. "
+     "Each task must include 'label', a brief 'explanation', and a short 'evidence' quote from the transcript."
+ )
+
+ DEFAULT_LABEL_GLOSSARY = {
+     "plan_contact": "Commitment to contact later (advisor/client will reach out, follow-up promised).",
+     "schedule_meeting": "Scheduling or confirming a meeting/call/appointment (time/date/slot/virtual).",
+     "update_contact_info_non_postal": "Change or confirmation of phone/email (non-postal contact details).",
+     "update_contact_info_postal_address": "Change or confirmation of postal/residential/mailing address.",
+     "update_kyc_activity": "Change/confirmation of occupation, employment status, or economic activity.",
+     "update_kyc_origin_of_assets": "Discussion/confirmation of source of funds / origin of assets.",
+     "update_kyc_purpose_of_businessrelation": "Purpose of the banking relationship/account usage.",
+     "update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
+ }

+ # Minimal multilingual fallback rules (optional)
+ DEFAULT_FALLBACK_CUES = {
+     "plan_contact": [
+         r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b", r"\bfollow\s*up\b", r"\breach out\b", r"\btouch base\b",
+         r"\bcontact (you|me|us)\b",
+         r"\bin verbindung setzen\b", r"\brückmeldung\b", r"\bich\s+melde\b|\bwir\s+melden\b", r"\bnachfassen\b",
+         r"\bje vous recontacte\b|\bnous vous recontacterons\b", r"\bprendre contact\b|\breprendre contact\b",
+         r"\bla ricontatter[oò]\b|\bci metteremo in contatto\b", r"\btenersi in contatto\b",
+     ],
+     "schedule_meeting": [
+         r"\b(let'?s\s+)?meet(ing|s)?\b", r"\bschedule( a)? (call|meeting|appointment)\b",
+         r"\bbook( a)? (slot|time|meeting)\b", r"\b(next week|tomorrow|this (afternoon|morning|evening))\b",
+         r"\bconfirm( the)? (time|meeting|appointment)\b",
+         r"\btermin(e|s)?\b|\bvereinbaren\b|\bansetzen\b|\babstimmen\b|\bbesprechung(en)?\b|\bvirtuell(e|en)?\b",
+         r"\bnächste(n|r)? woche\b|\b(dienstag|montag|mittwoch|donnerstag|freitag)\b|\bnachmittag|vormittag|morgen\b",
+         r"\brendez[- ]?vous\b|\bréunion\b|\bfixer\b|\bplanifier\b|\bse rencontrer\b|\bse voir\b",
+         r"\bla semaine prochaine\b|\bdemain\b|\bcet (après-midi|apres-midi|après midi|apres midi|matin|soir)\b",
+         r"\bappuntamento\b|\briunione\b|\borganizzare\b|\bprogrammare\b|\bincontrarci\b|\bcalendario\b",
+         r"\bla prossima settimana\b|\bdomani\b|\b(questo|questa)\s*(pomeriggio|mattina|sera)\b",
+     ],
+     "update_kyc_origin_of_assets": [
+         r"\bsource of funds\b|\borigin of assets\b|\bproof of (funds|assets)\b",
+         r"\bvermögensursprung(e|s)?\b|\bherkunft der mittel\b|\bnachweis\b",
+         r"\borigine des fonds\b|\borigine du patrimoine\b|\bjustificatif(s)?\b",
+         r"\borigine dei fondi\b|\borigine del patrimonio\b|\bprova dei fondi\b|\bgiustificativo\b",
+     ],
+     "update_kyc_activity": [
+         r"\bemployment status\b|\boccupation\b|\bjob change\b|\bsalary history\b",
+         r"\bbeschäftigungsstatus\b|\bberuf\b|\bjobwechsel\b|\bgehaltshistorie\b|\btätigkeit\b",
+         r"\bstatut professionnel\b|\bprofession\b|\bchangement d'emploi\b|\bhistorique salarial\b|\bactivité\b",
+         r"\bstato occupazionale\b|\bprofessione\b|\bcambio di lavoro\b|\bstoria salariale\b|\battivit[aà]\b",
+     ],
+ }
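+ # NOTE: multilingual_fallback() below matches these patterns against
+ # text.lower(), so all cues must be written in lowercase.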

+ # =========================
+ # Prompt template
+ # =========================
+ USER_PROMPT_TEMPLATE = (
+     "Transcript (may be DE/FR/IT/EN):\n"
+     "```\n{transcript}\n```\n\n"
+     "Allowed Labels (canonical; use only these):\n"
+     "{allowed_labels_list}\n\n"
+     "Label Glossary (concise semantics):\n"
+     "{glossary}\n\n"
+     "Return STRICT JSON ONLY in this exact schema:\n"
+     '{{\n  "labels": ["<Label1>", "..."],\n'
+     '  "tasks": [{{"label": "<Label1>", "explanation": "<why>", "evidence": "<quote>"}}]\n}}\n'
+ )
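+ # NOTE: the braces of the literal JSON schema are doubled ({{ ... }}) so that
+ # str.format() below only substitutes {transcript}, {allowed_labels_list},
+ # and {glossary} instead of choking on the schema's own braces.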

  # =========================
+ # Utilities
  # =========================
+ def _now_ms() -> int:
+     return int(time.time() * 1000)
+
+ def normalize_labels(labels: List[str]) -> List[str]:
+     return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
+
+ def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
+     return {lab.lower(): lab for lab in allowed}

+ def robust_json_extract(text: str) -> Dict[str, Any]:
+     if not text:
+         return {"labels": [], "tasks": []}
+     start, end = text.find("{"), text.rfind("}")
+     candidate = text[start:end+1] if (start != -1 and end != -1 and end > start) else text
+     try:
+         return json.loads(candidate)
+     except Exception:
+         candidate = re.sub(r",\s*}", "}", candidate)
+         candidate = re.sub(r",\s*]", "]", candidate)
+         try:
+             return json.loads(candidate)
+         except Exception:
+             return {"labels": [], "tasks": []}
+
+ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
+     out = {"labels": [], "tasks": []}
+     allowed_map = canonicalize_map(allowed)
+     filt_labels = []
+     for l in pred.get("labels", []) or []:
+         k = str(l).strip().lower()
+         if k in allowed_map:
+             filt_labels.append(allowed_map[k])
+     filt_labels = normalize_labels(filt_labels)
+     filt_tasks = []
+     for t in pred.get("tasks", []) or []:
+         if not isinstance(t, dict):
+             continue
+         k = str(t.get("label", "")).strip().lower()
+         if k in allowed_map:
+             new_t = {
+                 "label": allowed_map[k],
+                 "explanation": str(t.get("explanation", ""))[:300],
+                 "evidence": str(t.get("evidence", ""))[:300],
+             }
+             filt_tasks.append(new_t)
+     merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
+     out["labels"] = merged
+     out["tasks"] = filt_tasks
+     return out
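+ # e.g. restrict_to_allowed({"labels": ["Plan_Contact", "bogus"], "tasks": []}, OFFICIAL_LABELS)
+ # -> {"labels": ["plan_contact"], "tasks": []}: case-folded to canonical, unknowns dropped.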

  # =========================
+ # Pre-processing
  # =========================
+ _DISCLAIMER_PATTERNS = [
+     r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
+     r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
+     r"(?is)^\s*this message \(including any attachments\).+?(?:\n{2,}|$)",
+ ]
+ _FOOTER_PATTERNS = [
+     r"(?is)\n+kind regards[^\n]*\n.*$", r"(?is)\n+best regards[^\n]*\n.*$",
+     r"(?is)\n+sent from my.*$", r"(?is)\n+ubs ag.*$",
+ ]
+ _TIMESTAMP_SPEAKER = [
+     r"\[\d{1,2}:\d{2}(:\d{2})?\]",
+     r"^\s*(advisor|client|client advisor)\s*:\s*",
+     r"^\s*(speaker\s*\d+)\s*:\s*",
+ ]
+
+ def clean_transcript(text: str) -> str:
+     if not text:
+         return text
+     s = text
+     # strip speaker/timestamps
+     lines = []
+     for ln in s.splitlines():
+         ln2 = ln
+         for pat in _TIMESTAMP_SPEAKER:
+             ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
+         lines.append(ln2)
+     s = "\n".join(lines)
+     # disclaimers (top)
+     for pat in _DISCLAIMER_PATTERNS:
+         s = re.sub(pat, "", s).strip()
+     # footers
+     for pat in _FOOTER_PATTERNS:
+         s = re.sub(pat, "", s)
+     # whitespace tidy
+     s = re.sub(r"[ \t]+", " ", s)
+     s = re.sub(r"\n{3,}", "\n\n", s).strip()
+     return s
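+ # e.g. "[10:32] Advisor: Hello" -> "Hello" (timestamp, then speaker tag, stripped in turn).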
+
+ def read_text_file_any(file_input) -> str:
+     if not file_input:
+         return ""
+     if isinstance(file_input, (str, Path)):
+         try:
+             return Path(file_input).read_text(encoding="utf-8", errors="ignore")
+         except Exception:
+             return ""
+     try:
+         data = file_input.read()
+         return data.decode("utf-8", errors="ignore")
+     except Exception:
+         return ""
+
+ def read_json_file_any(file_input) -> Optional[dict]:
+     if not file_input:
          return None
+     if isinstance(file_input, (str, Path)):
+         try:
+             return json.loads(Path(file_input).read_text(encoding="utf-8", errors="ignore"))
+         except Exception:
+             return None
      try:
+         return json.loads(file_input.read().decode("utf-8", errors="ignore"))
      except Exception:
          return None

+ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
+     toks = tokenizer(text, add_special_tokens=False)["input_ids"]
+     if len(toks) <= max_tokens:
+         return text
+     return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
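+ # NOTE: truncation keeps the *tail* of the transcript (toks[-max_tokens:]), on the
+ # assumption that action items tend to cluster near the end of a call.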
+
+ # =========================
+ # HF model wrapper (robust loader + fast→slow tokenizer fallback)
+ # =========================
+ class ModelWrapper:
+     def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool):
+         self.repo_id = repo_id
+         self.hf_token = hf_token
+         self.load_in_4bit = load_in_4bit
+         self.use_sdpa = use_sdpa
+         self.tokenizer = None
+         self.model = None
+         self.load_path = "uninitialized"
+
+     def _load_tokenizer(self):
+         fast_err = None
+         tok = None
+         try:
+             tok = AutoTokenizer.from_pretrained(
+                 self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
+                 trust_remote_code=True, use_fast=True
+             )
+         except Exception as e:
+             fast_err = e
+         if tok is None:
+             tok = AutoTokenizer.from_pretrained(
+                 self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
+                 trust_remote_code=True, use_fast=False
+             )
+         if tok.pad_token is None and tok.eos_token:
+             tok.pad_token = tok.eos_token
+         return tok, fast_err
+
+     def load(self):
+         qcfg = None
+         if self.load_in_4bit and DEVICE == "cuda":
+             qcfg = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_use_double_quant=True,
+             )
+
+         tok, fast_err = self._load_tokenizer()
+
+         errors = []
+         for desc, kwargs in [
+             ("auto_device_no_lowcpu" + ("_sdpa" if (self.use_sdpa and DEVICE == "cuda") else ""),
+              dict(
+                  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+                  device_map="auto" if DEVICE == "cuda" else None,
+                  low_cpu_mem_usage=False,
+                  quantization_config=qcfg,
+                  trust_remote_code=True,
+                  cache_dir=str(SPACE_CACHE),
+                  attn_implementation=("sdpa" if (self.use_sdpa and DEVICE == "cuda") else None),
+              )),
+             ("auto_device_no_sdpa",
+              dict(
+                  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+                  device_map="auto" if DEVICE == "cuda" else None,
+                  low_cpu_mem_usage=False,
+                  quantization_config=qcfg,
+                  trust_remote_code=True,
+                  cache_dir=str(SPACE_CACHE),
+              )),
+             ("cpu_then_to_cuda" if DEVICE == "cuda" else "cpu_only",
+              dict(
+                  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+                  device_map=None,
+                  low_cpu_mem_usage=False,
+                  quantization_config=None if DEVICE != "cuda" else qcfg,
+                  trust_remote_code=True,
+                  cache_dir=str(SPACE_CACHE),
+              )),
+         ]:
+             try:
+                 mdl = AutoModelForCausalLM.from_pretrained(self.repo_id, token=self.hf_token, **kwargs)
+                 if desc.startswith("cpu_then_to_cuda") and DEVICE == "cuda":
+                     mdl = mdl.to(torch.device("cuda"))
+                 self.tokenizer = tok
+                 self.model = mdl
+                 self.load_path = desc + (" (fast tok)" if fast_err is None else " (slow tok)")
+                 return
+             except Exception as e:
+                 errors.append(f"{desc}: {e}")
+
+         extra = f"\nFast tokenizer error: {fast_err}" if fast_err else ""
+         raise RuntimeError("All load attempts failed:\n" + "\n".join(errors) + extra)
+
+     @torch.inference_mode()
+     def generate(self, system_prompt: str, user_prompt: str) -> str:
+         if hasattr(self.tokenizer, "apply_chat_template"):
+             messages = [
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt},
+             ]
+             input_ids = self.tokenizer.apply_chat_template(
+                 messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+             )
+             input_ids = input_ids.to(self.model.device)
+             gen_kwargs = dict(
+                 input_ids=input_ids,
+                 generation_config=GEN_CONFIG,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 pad_token_id=self.tokenizer.pad_token_id,
+             )
+         else:
+             enc = self.tokenizer(
+                 f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n",
+                 return_tensors="pt"
+             ).to(self.model.device)
+             gen_kwargs = dict(
+                 **enc,
+                 generation_config=GEN_CONFIG,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 pad_token_id=self.tokenizer.pad_token_id,
+             )
+
+         with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
+             out_ids = self.model.generate(**gen_kwargs)
+         return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
+
+ _MODEL_CACHE: Dict[str, ModelWrapper] = {}
+ def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool) -> ModelWrapper:
+     key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}::{'sdpa' if use_sdpa else 'nosdpa'}"
+     if key not in _MODEL_CACHE:
+         m = ModelWrapper(repo_id, hf_token, load_in_4bit, use_sdpa)
+         m.load()
+         _MODEL_CACHE[key] = m
+     return _MODEL_CACHE[key]
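+ # The cache key encodes the quantization and attention variants, so toggling those
+ # UI options loads a separate instance instead of reusing a mismatched one.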
+
  # =========================
+ # Evaluation (official weighted score)
  # =========================
  def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
+     ALLOWED_LABELS = OFFICIAL_LABELS
      LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
+
+     def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
+         if not isinstance(sample_labels, list):
+             raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
+         seen, uniq = set(), []
+         for label in sample_labels:
+             if not isinstance(label, str):
+                 raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
+             if label in seen:
+                 raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
+             if label not in ALLOWED_LABELS:
+                 raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
+             seen.add(label); uniq.append(label)
+         return uniq

      if len(y_true) != len(y_pred):
          raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")

      n_samples = len(y_true)
+     n_labels = len(OFFICIAL_LABELS)
      y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
      y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)

+     for i, sample_labels in enumerate(y_true):
+         for label in _process_sample_labels(sample_labels, f"y_true[{i}]"):
+             y_true_binary[i, LABEL_TO_IDX[label]] = 1
+
+     for i, sample_labels in enumerate(y_pred):
+         for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
+             y_pred_binary[i, LABEL_TO_IDX[label]] = 1
+
+     fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
+     fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
+     weighted = 2.0 * fn + 1.0 * fp
+     max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
+     per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
+     return float(max(0.0, min(1.0, np.mean(per_sample))))
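+ # Worked example with the 8 official labels: truth = ["schedule_meeting"], pred = []
+ # gives weighted = 2.0 (one FN), max_err = 2*1 + 1*7 = 9, score = 1 - 2/9 ≈ 0.778.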

  # =========================
+ # Multilingual regex fallback
  # =========================
+ def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
+     low = text.lower()
+     labels, tasks = [], []
+     for lab in allowed:
+         for pat in cues.get(lab, []):
+             m = re.search(pat, low)
+             if m:
+                 i = m.start()
+                 start = max(0, i - 60); end = min(len(text), i + len(m.group(0)) + 60)
+                 if lab not in labels:
+                     labels.append(lab)
+                     tasks.append({
+                         "label": lab,
+                         "explanation": "Rule hit (multilingual fallback)",
+                         "evidence": text[start:end].strip()
+                     })
+                 break
+     return {"labels": normalize_labels(labels), "tasks": tasks}
+
+ # =========================
+ # Inference helpers
+ # =========================
+ def build_glossary_str(glossary: Dict[str, str], allowed: List[str]) -> str:
+     return "\n".join([f"- {lab}: {glossary.get(lab, '')}" for lab in allowed])
+
+ def warmup_model(model_repo: str, use_4bit: bool, use_sdpa: bool, hf_token: str) -> str:
+     t0 = _now_ms()
+     try:
+         model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
+         _ = model.generate("Return JSON only.", '{"labels": [], "tasks": []}')
+         return f"Warm-up complete in {_now_ms() - t0} ms. Load path: {model.load_path}"
+     except Exception as e:
+         return f"Warm-up failed: {e}"
+
+ def run_single(
+     transcript_text: str,
+     transcript_file,
+     gt_json_text: str,
+     gt_json_file,
+     use_cleaning: bool,
+     use_fallback: bool,
+     allowed_labels_text: str,
+     sys_instructions_text: str,
+     glossary_json_text: str,
+     fallback_json_text: str,
+     model_repo: str,
+     use_4bit: bool,
+     use_sdpa: bool,
+     max_input_tokens: int,
+     hf_token: str,
+ ) -> Tuple[str, str, str, str, str, str, str, str, str]:
+
+     t0 = _now_ms()
+
+     # Transcript
+     raw_text = ""
+     if transcript_file:
+         raw_text = read_text_file_any(transcript_file)
+     raw_text = (raw_text or transcript_text or "").strip()
+     if not raw_text:
+         return "", "", "No transcript provided.", "", "", "", "", "", ""
+
+     text = clean_transcript(raw_text) if use_cleaning else raw_text
+
+     # Allowed labels
+     user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
+     allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
+
+     # Editable configs
+     try:
+         sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip() or DEFAULT_SYSTEM_INSTRUCTIONS
+     except Exception:
+         sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
+
+     try:
+         label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
+     except Exception:
+         label_glossary = DEFAULT_LABEL_GLOSSARY
+
+     try:
+         fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
+     except Exception:
+         fallback_cues = DEFAULT_FALLBACK_CUES

+     # Model
      try:
+         model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
      except Exception as e:
+         return "", "", f"Model load failed: {e}", "", "", "", "", "", ""
+
+     # Truncate
+     trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
+
+     # Build prompt
+     glossary_str = build_glossary_str(label_glossary, allowed)
+     allowed_list_str = "\n".join(f"- {l}" for l in allowed)
+     user_prompt = USER_PROMPT_TEMPLATE.format(
+         transcript=trunc,
+         allowed_labels_list=allowed_list_str,
+         glossary=glossary_str,
      )

+     # Token info + prompt preview
+     transcript_tokens = len(model.tokenizer(trunc, add_special_tokens=False)["input_ids"])
+     prompt_tokens = len(model.tokenizer(user_prompt, add_special_tokens=False)["input_ids"])
+     token_info_text = f"Transcript tokens: {transcript_tokens} | Prompt tokens: {prompt_tokens} | Load path: {model.load_path}"
+     prompt_preview_text = "```\n" + user_prompt[:4000] + ("\n... (truncated)" if len(user_prompt) > 4000 else "") + "\n```"
+
+     # Generate
+     t1 = _now_ms()
      try:
+         out = model.generate(sys_instructions, user_prompt)
      except Exception as e:
+         return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
+     t2 = _now_ms()
+
+     parsed = robust_json_extract(out)
+     filtered = restrict_to_allowed(parsed, allowed)
+
+     # Fallback merge for recall
+     if use_fallback:
+         fb = multilingual_fallback(trunc, allowed, fallback_cues)
+         if fb["labels"]:
+             merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
+             existing = {tt.get("label") for tt in filtered.get("tasks", [])}
+             merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
+             filtered = {"labels": merged_labels, "tasks": merged_tasks}
+
+     # Diagnostics
+     diag = "\n".join([
+         f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
+         f"Model: {model_repo}",
+         f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
+         f"Fallback rules: {'Yes' if use_fallback else 'No'}",
+         f"SDPA attention: {'Yes' if use_sdpa else 'No'}",
+         f"Tokens (input limit): {max_input_tokens}",
+         f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
+         f"Allowed labels: {', '.join(allowed)}",
      ])

+     # Summaries
+     labs = filtered.get("labels", [])
+     tasks = filtered.get("tasks", [])
+     summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
+     if tasks:
+         summary += "\n\nTasks:\n" + "\n".join(
+             f"• [{t['label']}] {t.get('explanation','')} | ev: {t.get('evidence','')[:140]}{'…' if len(t.get('evidence',''))>140 else ''}"
+             for t in tasks
+         )
+     else:
+         summary += "\n\nTasks: (none)"
+     json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
+
+     # Single-file metrics if GT provided
+     metrics = ""
+     if gt_json_file or (gt_json_text and gt_json_text.strip()):
+         truth_obj = None
+         if gt_json_file:
+             truth_obj = read_json_file_any(gt_json_file)
+         if (not truth_obj) and gt_json_text:
+             try:
+                 truth_obj = json.loads(gt_json_text)
+             except Exception:
+                 pass
+         if isinstance(truth_obj, dict) and isinstance(truth_obj.get("labels"), list):
+             true_labels = [x for x in truth_obj["labels"] if x in OFFICIAL_LABELS]
+             pred_labels = labs
+             try:
+                 score = evaluate_predictions([true_labels], [pred_labels])
+                 tp = len(set(true_labels) & set(pred_labels))
+                 fp = len(set(pred_labels) - set(true_labels))
+                 fn = len(set(true_labels) - set(pred_labels))
+                 recall = tp / (tp + fn) if (tp + fn) else 1.0
+                 precision = tp / (tp + fp) if (tp + fp) else 1.0
+                 f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
+                 metrics = (
+                     f"Weighted score: {score:.3f}\n"
+                     f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}\n"
+                     f"TP={tp} FP={fp} FN={fn}\n"
+                     f"Truth: {', '.join(true_labels)}"
+                 )
+             except Exception as e:
+                 metrics = f"Scoring error: {e}"
+         else:
+             metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
+
+     # Previews
+     context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in label_glossary.items() if k in allowed)
+     instructions_preview = "```\n" + sys_instructions + "\n```"
+
+     return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text

  # =========================
+ # Batch mode
  # =========================
+ def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
+     exdir.mkdir(parents=True, exist_ok=True)
+     with open(path, "rb") as f:
+         data = f.read()
+     with zipfile.ZipFile(io.BytesIO(data)) as zf:
+         zf.extractall(exdir)
+     return [p for p in exdir.rglob("*") if p.is_file()]
+
+ def run_batch(
+     zip_path,
+     use_cleaning: bool,
+     use_fallback: bool,
+     sys_instructions_text: str,
+     glossary_json_text: str,
+     fallback_json_text: str,
+     model_repo: str,
+     use_4bit: bool,
+     use_sdpa: bool,
+     max_input_tokens: int,
+     hf_token: str,
+     limit_files: int,
+ ) -> Tuple[str, str, pd.DataFrame, str]:
+
+     if not zip_path:
+         return ("No ZIP provided.", "", pd.DataFrame(), "")

      try:
+         sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip() or DEFAULT_SYSTEM_INSTRUCTIONS
+     except Exception:
+         sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
+     try:
+         label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
+     except Exception:
+         label_glossary = DEFAULT_LABEL_GLOSSARY
+     try:
+         fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
+     except Exception:
+         fallback_cues = DEFAULT_FALLBACK_CUES
+
+     # Workspace
+     work = Path("/tmp/batch")
+     if work.exists():
+         for p in sorted(work.rglob("*"), reverse=True):
+             try: p.unlink()
+             except Exception: pass
+         try: work.rmdir()
+         except Exception: pass
+     work.mkdir(parents=True, exist_ok=True)
+
+     files = read_zip_from_path(zip_path, work)
+     txts: Dict[str, Path] = {}
+     gts: Dict[str, Path] = {}
+     for p in files:
+         if p.suffix.lower() == ".txt":
+             txts[p.stem] = p
+         elif p.suffix.lower() == ".json":
+             gts[p.stem] = p
+
+     stems = sorted(txts.keys())
+     if limit_files > 0:
+         stems = stems[:limit_files]
+     if not stems:
+         return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
+
+     # Model
+     try:
+         model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
      except Exception as e:
+         return (f"Model load failed: {e}", "", pd.DataFrame(), "")
+
+     allowed = OFFICIAL_LABELS[:]
+     glossary_str = build_glossary_str(label_glossary, allowed)
+     allowed_list_str = "\n".join(f"- {l}" for l in allowed)
+
+     y_true, y_pred = [], []
+     rows = []
+     t_start = _now_ms()
+
+     for stem in stems:
+         raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
+         text = clean_transcript(raw) if use_cleaning else raw
+
+         trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
+         user_prompt = USER_PROMPT_TEMPLATE.format(
+             transcript=trunc,
+             allowed_labels_list=allowed_list_str,
+             glossary=glossary_str,
+         )
+
+         t0 = _now_ms()
+         out = model.generate(sys_instructions, user_prompt)
+         t1 = _now_ms()
+
+         parsed = robust_json_extract(out)
+         filtered = restrict_to_allowed(parsed, allowed)
+
+         if use_fallback:
+             fb = multilingual_fallback(trunc, allowed, fallback_cues)
+             if fb["labels"]:
+                 merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
+                 existing = {tt.get("label") for tt in filtered.get("tasks", [])}
+                 merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
+                 filtered = {"labels": merged_labels, "tasks": merged_tasks}
+
+         pred_labels = filtered.get("labels", [])
+         y_pred.append(pred_labels)
+
+         gt_labels = []
+         if stem in gts:
+             try:
+                 gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
+                 if isinstance(gt_obj, dict) and isinstance(gt_obj.get("labels"), list):
+                     gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
+             except Exception:
+                 pass
+         y_true.append(gt_labels)
+
+         gt_set, pr_set = set(gt_labels), set(pred_labels)
+         tp = sorted(gt_set & pr_set)
+         fp = sorted(pr_set - gt_set)
+         fn = sorted(gt_set - pr_set)
+
+         rows.append({
+             "file": stem,
+             "true_labels": ", ".join(gt_labels),
+             "pred_labels": ", ".join(pred_labels),
+             "TP": len(tp), "FP": len(fp), "FN": len(fn),
+             "gen_ms": t1 - t0
+         })
+
+     have_truth = any(len(v) > 0 for v in y_true)
+     score = evaluate_predictions(y_true, y_pred) if have_truth else None
+
+     df = pd.DataFrame(rows).sort_values(["FN", "FP", "file"])
+     diag = [
+         f"Processed files: {len(stems)}",
+         f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
+         f"Model: {model_repo}",
+         f"Fallback rules: {'Yes' if use_fallback else 'No'}",
+         f"SDPA attention: {'Yes' if use_sdpa else 'No'}",
+         f"Tokens (input limit): ≤ {max_input_tokens}",
+         f"Batch time: {_now_ms()-t_start} ms",
+     ]
+     if have_truth and score is not None:
+         total_tp = int(df["TP"].sum())
+         total_fp = int(df["FP"].sum())
+         total_fn = int(df["FN"].sum())
+         recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 1.0
+         precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 1.0
+         f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
+         diag += [
+             f"Official weighted score (0–1): {score:.3f}",
+             f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}",
+             f"Total TP={total_tp} FP={total_fp} FN={total_fn}",
+         ]
+     diag_str = "\n".join(diag)
+
+     out_csv = Path("/tmp/batch_results.csv")
+     df.to_csv(out_csv, index=False, encoding="utf-8")
+     return ("Batch done.", diag_str, df, str(out_csv))

  # =========================
+ # UI
  # =========================
+ MODEL_CHOICES = [
+     "swiss-ai/Apertus-8B-Instruct-2509",  # multilingual
+     "meta-llama/Meta-Llama-3-8B-Instruct",
+     "mistralai/Mistral-7B-Instruct-v0.3",
+ ]
+
+ # White, modern UI (no purple)
+ custom_css = """
+ :root { --radius: 14px; }
+ .gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
+ .card { border: 1px solid #e5e7eb; border-radius: var(--radius); padding: 14px 16px; background: #ffffff; box-shadow: 0 1px 2px rgba(0,0,0,.03); }
+ .header { font-weight: 700; font-size: 22px; margin-bottom: 4px; color: #0f172a; }
+ .subtle { color: #475569; font-size: 14px; margin-bottom: 12px; }
+ hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 10px 0 16px; }
+ .gr-button { border-radius: 12px !important; }
+ a, .prose a { color: #0ea5e9; }
  """

+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
+     gr.Markdown("<div class='header'>Talk2Task — Multilingual Task Extraction (UBS Challenge)</div>")
+     gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN) with compact prompts. Optional rule fallback ensures recall. Batch evaluation & scoring included.</div>")

+     with gr.Tab("Single transcript"):
          with gr.Row():
+             with gr.Column(scale=3):
+                 gr.Markdown("<div class='card'><div class='header'>Transcript</div>")
+                 file = gr.File(
+                     label="Drag & drop transcript (.txt / .md / .json)",
+                     file_types=[".txt", ".md", ".json"],
+                     type="filepath",
                  )
+                 text = gr.Textbox(label="Or paste transcript", lines=10, placeholder="Paste transcript in DE/FR/IT/EN…")
+                 gr.Markdown("<hr class='sep'/>")
+
+                 gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
+                 gt_file = gr.File(
+                     label="Upload ground truth JSON (expects {'labels': [...]})",
+                     file_types=[".json"],
+                     type="filepath",
                  )
+                 gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{"labels": ["schedule_meeting"]}')
+                 gr.Markdown("</div>")  # close card
+
+                 gr.Markdown("<div class='card'><div class='header'>Processing options</div>")
+                 use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)", value=True)
+                 use_fallback = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
+                 gr.Markdown("</div>")
+
+                 gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
+                 labels_text = gr.Textbox(label="Allowed Labels (one per line)", value=OFFICIAL_LABELS_TEXT, lines=8)
+                 reset_btn = gr.Button("Reset to official labels")
+                 gr.Markdown("</div>")
+
+                 gr.Markdown("<div class='card'><div class='header'>Editable instructions & context</div>")
+                 sys_instr_tb = gr.Textbox(label="System Instructions (editable)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=5)
+                 glossary_tb = gr.Code(label="Label Glossary (JSON; editable)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
+                 fallback_tb = gr.Code(label="Fallback Cues (Multilingual, JSON; editable)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
+                 gr.Markdown("</div>")
+
+             with gr.Column(scale=2):
+                 gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
+                 repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
+                 use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
+                 use_sdpa = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
+                 max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
+                 hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN", ""))
+                 warm_btn = gr.Button("Warm up model (load & compile kernels)")
+                 run_btn = gr.Button("Run Extraction", variant="primary")
+                 gr.Markdown("</div>")
+
+                 gr.Markdown("<div class='card'><div class='header'>Outputs</div>")
+                 summary = gr.Textbox(label="Summary", lines=12)
+                 json_out = gr.Code(label="Strict JSON Output", language="json")
+                 diag = gr.Textbox(label="Diagnostics", lines=10)
+                 raw = gr.Textbox(label="Raw Model Output", lines=8)
+                 prompt_preview = gr.Code(label="Prompt preview (user prompt sent)", language="markdown")
+                 token_info = gr.Textbox(label="Token counts (transcript / prompt / load path)", lines=2)
+                 gr.Markdown("</div>")

+         with gr.Row():
+             with gr.Column():
+                 with gr.Accordion("Instructions used (system prompt)", open=False):
+                     instr_md = gr.Markdown("```\n" + DEFAULT_SYSTEM_INSTRUCTIONS + "\n```")
+             with gr.Column():
+                 with gr.Accordion("Context used (glossary)", open=True):
+                     context_md = gr.Markdown("")
+
+         # Reset labels
+         reset_btn.click(fn=lambda: OFFICIAL_LABELS_TEXT, inputs=None, outputs=labels_text)
+         # Warm-up
+         warm_btn.click(fn=warmup_model, inputs=[repo, use_4bit, use_sdpa, hf_token], outputs=diag)
+
+         def _pack_context_md(glossary_json, allowed_text):
+             try:
+                 glossary = json.loads(glossary_json) if glossary_json else DEFAULT_LABEL_GLOSSARY
+             except Exception:
+                 glossary = DEFAULT_LABEL_GLOSSARY
+             allowed_list = [ln.strip() for ln in (allowed_text or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
+             return "### Label Glossary (used)\n" + "\n".join(f"- {k}: {glossary.get(k, '')}" for k in allowed_list)
+
+         context_md.value = _pack_context_md(json.dumps(DEFAULT_LABEL_GLOSSARY), OFFICIAL_LABELS_TEXT)
+
+         # Run single
+         run_btn.click(
+             fn=run_single,
+             inputs=[
+                 text, file, gt_text, gt_file, use_cleaning, use_fallback,
+                 labels_text, sys_instr_tb, glossary_tb, fallback_tb,
+                 repo, use_4bit, use_sdpa, max_tokens, hf_token
+             ],
+             outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
+         )
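+         # NOTE: run_single's 7th return value (metrics) lands in the hidden Textbox
+         # created inline above, so single-file metrics are computed but not displayed.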

+     with gr.Tab("Batch evaluation"):
+         with gr.Row():
+             with gr.Column(scale=3):
+                 gr.Markdown("<div class='card'><div class='header'>ZIP input</div>")
+                 zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
+                 use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
+                 use_fallback_b = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
+                 gr.Markdown("</div>")
+             with gr.Column(scale=2):
+                 gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
+                 repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
+                 use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
+                 use_sdpa_b = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
+                 max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
+                 hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN", ""))
+                 sys_instr_tb_b = gr.Textbox(label="System Instructions (editable for batch)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=4)
+                 glossary_tb_b = gr.Code(label="Label Glossary (JSON; editable for batch)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
+                 fallback_tb_b = gr.Code(label="Fallback Cues (Multilingual, JSON; editable for batch)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
+                 limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
+                 run_batch_btn = gr.Button("Run Batch", variant="primary")
+                 gr.Markdown("</div>")

+         with gr.Row():
+             gr.Markdown("<div class='card'><div class='header'>Batch outputs</div>")
+             status = gr.Textbox(label="Status", lines=1)
+             diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
+             df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
+             csv_out = gr.File(label="Download CSV", interactive=False)
+             gr.Markdown("</div>")
+
+         run_batch_btn.click(
+             fn=run_batch,
+             inputs=[
+                 zip_in, use_cleaning_b, use_fallback_b,
+                 sys_instr_tb_b, glossary_tb_b, fallback_tb_b,
+                 repo_b, use_4bit_b, use_sdpa_b, max_tokens_b, hf_token_b, limit_files
+             ],
+             outputs=[status, diag_b, df_out, csv_out],
+         )

  if __name__ == "__main__":
      demo.launch()