RishiRP committed
Commit e372e2c · verified · 1 parent: ae411c8

Update app.py

Files changed (1): app.py (+161, -179)
app.py CHANGED
Old version (removed lines marked "-", unchanged context unmarked):

@@ -1,99 +1,113 @@
  # app.py
- # From Talk to Task — Batch & Single Task Extraction
- # Works on CPU / GPU / ZeroGPU. Uses a writable HF cache path (no /data).
- # If you want to use gated models (e.g., mistralai/Mistral-7B-Instruct-v0.2),
- # accept the license on HF and set HF_TOKEN in Space → Settings → Secrets.

  import os
  import io
  import re
- import sys
  import time
  import json
  import zipfile
  from pathlib import Path
- from typing import List, Dict, Tuple, Optional

  import gradio as gr

- # ====== Robust, writable HF cache ======
- # Avoid /data (read-only in Spaces). Prefer $HOME or /tmp.
  HOME = Path(os.environ.get("HOME", "/home/user"))
  CACHE_DIR = HOME / ".cache" / "huggingface"
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
  os.environ.setdefault("HF_HOME", str(CACHE_DIR))
- # NOTE: TRANSFORMERS_CACHE is deprecated; HF_HOME is enough.
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster downloads when available

- HF_TOKEN = os.environ.get("HF_TOKEN", "").strip() or None

- # ====== Transformers safe import ======
  try:
      import torch
-     from transformers import (
-         AutoTokenizer,
-         AutoModelForCausalLM,
-         BitsAndBytesConfig,
-     )
  except Exception as e:
      raise RuntimeError(
-         "Failed to import transformers/torch. "
-         "Make sure requirements.txt includes: transformers>=4.41, torch, accelerate"
      ) from e

- DTYPE_FALLBACK = torch.float32
- if torch.cuda.is_available():
-     DTYPE_FALLBACK = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- # ====== ZeroGPU (optional) ======
  try:
      import spaces  # noqa: F401
      ON_ZERO_GPU = True
  except Exception:
      ON_ZERO_GPU = False

- # ====== UI presets ======
- OPEN_MODEL_PRESETS = [
-     "HuggingFaceH4/zephyr-7b-beta",
-     "Qwen/Qwen2.5-7B-Instruct",
-     "tiiuae/falcon-7b-instruct",
- ]

- PINNED_REVISIONS = {
-     "HuggingFaceH4/zephyr-7b-beta": None,
-     "Qwen/Qwen2.5-7B-Instruct": None,
-     "tiiuae/falcon-7b-instruct": None,
-     # "mistralai/Mistral-7B-Instruct-v0.2": None,  # gated — use only if token + license ok
- }
-
- SYSTEM_INSTRUCTIONS = (
-     "You are a task extraction assistant. Always output valid JSON with a field "
-     '"labels" (list of strings). Use only from this set: '
-     '["plan_contact","schedule_meeting","update_contact_info_non_postal",'
-     '"update_contact_info_postal_address","update_kyc_activity","update_kyc_origin_of_assets",'
-     '"update_kyc_purpose_of_businessrelation","update_kyc_total_assets"]. '
-     "Return JSON only."
- )
-
- CONTEXT_GUIDE = """\
- - plan_contact: conversation without a concrete meeting (no date/time)
- - schedule_meeting: explicit date/time/modality confirmation
- - update_contact_info_non_postal: changes to email/phone
- - update_contact_info_postal_address: changes to mailing address
- - update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)
- """
-
- # ====== Utility ======
- def _json_only(text: str) -> str:
-     text = text.strip()
-     if text.startswith("{") and text.endswith("}"):
-         return text
-     m = re.search(r"\{.*\}", text, re.DOTALL)
      return m.group(0) if m else '{"labels": []}'

- def safe_json_loads(s: str) -> dict:
      try:
          return json.loads(s)
      except Exception:
@@ -101,78 +115,53 @@ def safe_json_loads(s: str) -> dict:

  def build_prompt(system: str, context: str, transcript: str) -> str:
      return (
-         f"### System:\n{system}\n\n"
-         f"### Context:\n{context}\n\n"
-         f"### Transcript:\n{transcript}\n\n"
-         "### Output:\nReturn JSON only."
      )

- # ====== Model wrapper ======
  class HFModel:
      def __init__(
          self,
          repo_id: str,
-         revision: Optional[str] = None,
-         load_in_4bit: bool = False,
-         trust_remote_code: bool = True,
-         dtype: Optional[torch.dtype] = None,
-         token: Optional[str] = None,
-     ) -> None:
          self.repo_id = repo_id
          self.revision = revision or "main"
-         self.trust_remote_code = trust_remote_code
          self.token = token
-         self.dtype = dtype or DTYPE_FALLBACK
          self.load_in_4bit = load_in_4bit and (DEVICE == "cuda")
          self.tokenizer = None
          self.model = None

      def load(self):
-         quant_cfg = None
-         if self.load_in_4bit:
-             quant_cfg = BitsAndBytesConfig(load_in_4bit=True)
-         try:
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 self.repo_id,
-                 revision=self.revision,
-                 token=self.token,
-                 cache_dir=str(CACHE_DIR),
-                 trust_remote_code=self.trust_remote_code,
-                 use_fast=True,
-             )
-         except Exception as e:
-             raise RuntimeError(
-                 f"Failed to load tokenizer for {self.repo_id} "
-                 "(If gated, accept license and set HF_TOKEN in Space → Settings → Secrets)."
-             ) from e
-
-         try:
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 self.repo_id,
-                 revision=self.revision,
-                 token=self.token,
-                 cache_dir=str(CACHE_DIR),
-                 trust_remote_code=self.trust_remote_code,
-                 torch_dtype=self.dtype,
-                 device_map="auto" if DEVICE == "cuda" else None,
-                 quantization_config=quant_cfg,
-                 low_cpu_mem_usage=True,
-             )
-             if DEVICE == "cpu":
-                 self.model = self.model.to(DEVICE)
-         except Exception as e:
-             raise RuntimeError(
-                 f"Failed to load model weights for {self.repo_id}. "
-                 "Check license, token, and hardware availability."
-             ) from e

      @torch.inference_mode()
-     def generate(self, prompt: str, max_new_tokens: int = 256, temperature: float = 0.1) -> str:
          tok = self.tokenizer
          mdl = self.model
          if tok.pad_token is None:
              tok.pad_token = tok.eos_token
-
          inputs = tok(prompt, return_tensors="pt").to(mdl.device)
          out = mdl.generate(
              **inputs,
@@ -183,28 +172,29 @@ class HFModel:
              pad_token_id=tok.eos_token_id,
              eos_token_id=tok.eos_token_id,
          )
-         text = tok.decode(out[0], skip_special_tokens=True)
-         gen = text[len(prompt):].strip() if text.startswith(prompt) else text
-         return _json_only(gen)

- # ====== Model cache (per Space worker) ======
  _MODEL_CACHE: Dict[Tuple[str, Optional[str], bool], HFModel] = {}

  def get_model(repo_id: str, revision: Optional[str], load_in_4bit: bool) -> HFModel:
      key = (repo_id, revision, load_in_4bit)
      if key in _MODEL_CACHE:
          return _MODEL_CACHE[key]
-     model = HFModel(
-         repo_id=repo_id,
-         revision=revision,
-         load_in_4bit=load_in_4bit,
-         token=HF_TOKEN,
-     )
-     model.load()
-     _MODEL_CACHE[key] = model
-     return model

- # ====== Single transcript inference ======
  def run_single(
      model_choice: str,
      custom_repo_id: str,
@@ -217,18 +207,18 @@ def run_single(
      add_header: bool,
      strip_smalltalk: bool,
      load_in_4bit: bool,
- ) -> Tuple[str, str, str, str]:
      debug = []
      t0 = time.perf_counter()

-     repo = (custom_repo_id or model_choice).strip()
      rev = PINNED_REVISIONS.get(repo, None)
-     debug.append(f"Repo: {repo} | Revision: {rev or 'main'} | 4bit: {load_in_4bit} | Device: {DEVICE}")

      if preprocess:
          lines = [ln.rstrip() for ln in transcript.splitlines()]
          if strip_smalltalk:
-             lines = [ln for ln in lines if not re.search(r"\b(thanks?|bye|ok(ay)?)\b", ln, re.I)]
          transcript = "\n".join(lines[-32768:])
          if add_header:
              transcript = f"[EMAIL/MESSAGE SIGNAL]\n{transcript}"
@@ -242,31 +232,30 @@ def run_single(

      try:
          model = get_model(repo, rev, load_in_4bit)
-         raw = model.generate(prompt, max_new_tokens=256, temperature=0.1)
-         data = safe_json_loads(raw)
          out_json = json.dumps(data, ensure_ascii=False)
-         debug.append(f"Generation OK in {time.perf_counter()-t0:.2f}s")
          return repo, (rev or "main"), out_json, "\n".join(debug)
      except Exception as e:
          debug.append(f"ERROR: {e}")
          return repo, (rev or "main"), json.dumps({"labels": []}), "\n".join(debug)

- # ====== Batch (ZIP of many .txt files) ======
  def run_batch(
      model_choice: str,
      custom_repo_id: str,
      system: str,
      context: str,
-     zip_file: Optional[io.BytesIO],
      soft_token_cap: int,
      preprocess: bool,
      lines_window: int,
      add_header: bool,
      strip_smalltalk: bool,
      load_in_4bit: bool,
- ) -> Tuple[str, str, str, str]:
      debug = []
-     repo = (custom_repo_id or model_choice).strip()
      rev = PINNED_REVISIONS.get(repo, None)

      if not zip_file:
@@ -275,12 +264,13 @@ def run_batch(
      try:
          z = zipfile.ZipFile(zip_file)
          names = [n for n in z.namelist() if n.lower().endswith(".txt")]
-         debug.append(f"Files detected: {len(names)}")
      except Exception as e:
          return repo, (rev or "main"), "filename,labels\n", f"Bad ZIP: {e}"

      try:
-         model = get_model(repo, rev, load_in_4bit)
      except Exception as e:
          return repo, (rev or "main"), "filename,labels\n", f"Model load error: {e}"

@@ -288,46 +278,43 @@ def run_batch(
      for name in names:
          try:
              txt = z.read(name).decode("utf-8", errors="replace")
-             _, _, labels_json, _ = run_single(
                  model_choice, custom_repo_id, system, context, txt,
-                 soft_token_cap, preprocess, lines_window, add_header,
-                 strip_smalltalk, load_in_4bit
              )
-             labels = safe_json_loads(labels_json).get("labels", [])
              rows.append(f"{name},{json.dumps(labels, ensure_ascii=False)}")
          except Exception as e:
              rows.append(f"{name},[] # error: {e}")

      return repo, (rev or "main"), "\n".join(rows), "\n".join(debug)

- # ====== Gradio UI ======
- with gr.Blocks(title="From Talk to Task — Batch & Single Task Extraction") as demo:
      gr.Markdown(
-         """
-         # From Talk to Task — Batch & Single Task Extraction

-         **Tip:** Use **open models** first (no gating). If you pick a gated model, make sure
-         you have accepted its license _and_ set `HF_TOKEN` in **Settings → Secrets**.

-         **Pinned revisions:** {}
-         """.format(
-             ", ".join([f"{k}@{v or 'main'}" for k, v in PINNED_REVISIONS.items()])
-         )
      )

      with gr.Row():
          model_choice = gr.Dropdown(
-             OPEN_MODEL_PRESETS,
-             label="Model (Open presets no gating)",
-             value=OPEN_MODEL_PRESETS[0],
          )
-         custom_repo_id = gr.Textbox(
              label="Custom model repo id (overrides preset)",
-             placeholder="e.g. mistralai/Mistral-7B-Instruct-v0.2 (requires license + HF_TOKEN)"
          )

-     system = gr.Textbox(label="Instructions (System)", value=SYSTEM_INSTRUCTIONS, lines=5)
-     context = gr.Textbox(label="Context (User prefix before transcript)", value=CONTEXT_GUIDE, lines=6)

      with gr.Row():
          soft_cap = gr.Slider(1024, 32768, value=8192, step=1, label="Soft token cap")
@@ -339,22 +326,21 @@ with gr.Blocks(title="From Talk to Task — Batch & Single Task Extraction") as demo:
          load_4bit = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")

      with gr.Tabs():
-         with gr.Tab("Single Transcript (default)"):
-             transcript = gr.Textbox(label="Paste transcript text", lines=12, placeholder="Paste your transcript here...")
              run_btn = gr.Button("Run (Single)", variant="primary")
              repo_used = gr.Textbox(label="Repo used", interactive=False)
              rev_used = gr.Textbox(label="Revision", interactive=False)
              json_out = gr.Code(label="JSON Output", language="json")
              debug_out = gr.Textbox(label="Diagnostics", lines=6)

-             def _run_single(*args):
-                 r, v, j, d = run_single(*args)
-                 return r, v, j, d

              run_btn.click(
-                 _run_single,
                  inputs=[
-                     model_choice, custom_repo_id, system, context, transcript,
                      soft_cap, preprocess, lines_window, add_header, strip_smalltalk, load_4bit
                  ],
                  outputs=[repo_used, rev_used, json_out, debug_out],
@@ -365,29 +351,25 @@ with gr.Blocks(title="From Talk to Task — Batch & Single Task Extraction") as demo:
              run_batch_btn = gr.Button("Run (Batch)", variant="primary")
              repo_used_b = gr.Textbox(label="Repo used", interactive=False)
              rev_used_b = gr.Textbox(label="Revision", interactive=False)
-             # FIX: use Textbox for CSV; Code(language="text") is not supported.
              csv_out = gr.Textbox(label="CSV (filename,labels)", lines=12)
              debug_out_b = gr.Textbox(label="Diagnostics", lines=6)

-             def _run_batch(*args):
-                 r, v, c, d = run_batch(*args)
-                 return r, v, c, d

              run_batch_btn.click(
-                 _run_batch,
                  inputs=[
-                     model_choice, custom_repo_id, system, context, zip_in,
                      soft_cap, preprocess, lines_window, add_header, strip_smalltalk, load_4bit
                  ],
                  outputs=[repo_used_b, rev_used_b, csv_out, debug_out_b],
              )

      gr.Markdown(
-         f"""
-         - **HF_TOKEN detected:** {"✅ yes" if HF_TOKEN else "⚠️ no (only needed for gated models)"}
-         - **Device:** {DEVICE}
-         - **Cache dir:** `{CACHE_DIR}`
-         """
      )

  if __name__ == "__main__":
  # app.py
2
+ # From Talk to Task — Batch & Single Task Extraction (Multilingual: EN/FR/DE/IT)
3
+ # Default model: Swiss Apertus instruct (set APERTUS_REPO below).
4
+ # Works on CPU / GPU / ZeroGPU. Uses a writable HF cache. JSON-only outputs.
 
5
 
6
  import os
7
  import io
8
  import re
 
9
  import time
10
  import json
11
  import zipfile
12
  from pathlib import Path
13
+ from typing import Dict, Tuple, Optional
14
 
15
  import gradio as gr
16
 
17
+ # --------------------------- CONFIG ---------------------------------
18
+
19
+ # <<< SET THIS TO YOUR APERTUS MODEL REPO ID >>>
20
+ # Example: "ApertusAI/swiss-apertus-7b-instruct" (replace with your actual repo id)
21
+ APERTUS_REPO = "swiss-ai/Apertus-8B-Instruct-2509"
22
+
23
+ # Optional: fallback open models (no gating) to sanity-check UI quickly
24
+ OPEN_FALLBACKS = [
25
+ "HuggingFaceH4/zephyr-7b-beta",
26
+ "Qwen/Qwen2.5-7B-Instruct",
27
+ "tiiuae/falcon-7b-instruct",
28
+ ]
29
+
30
+ PINNED_REVISIONS = {
31
+ # None => "main"
32
+ # Put your Apertus revision here if you want to pin it:
33
+ APERTUS_REPO: None,
34
+ "HuggingFaceH4/zephyr-7b-beta": None,
35
+ "Qwen/Qwen2.5-7B-Instruct": None,
36
+ "tiiuae/falcon-7b-instruct": None,
37
+ }
38
+
39
+ # Multilingual, but labels must be English and from this fixed set:
40
+ LABEL_SET = [
41
+ "plan_contact",
42
+ "schedule_meeting",
43
+ "update_contact_info_non_postal",
44
+ "update_contact_info_postal_address",
45
+ "update_kyc_activity",
46
+ "update_kyc_origin_of_assets",
47
+ "update_kyc_purpose_of_businessrelation",
48
+ "update_kyc_total_assets",
49
+ ]
50
+
51
+ SYSTEM_INSTRUCTIONS = (
52
+ "You are a task extraction assistant.\n"
53
+ "Input transcript language can be English, French, German, or Italian. "
54
+ "You MUST output valid JSON ONLY (no prose), with a single field:\n"
55
+ '"labels": a list of strings chosen ONLY from the set:\n'
56
+ f"{LABEL_SET}\n"
57
+ "Do not invent other fields. Do not translate labels. Return JSON only."
58
+ )
59
+
60
+ CONTEXT_GUIDE = (
61
+ "- plan_contact: contact without firm date/time\n"
62
+ "- schedule_meeting: explicit date/time/modality confirmed\n"
63
+ "- update_contact_info_non_postal: email/phone updates\n"
64
+ "- update_contact_info_postal_address: mailing address updates\n"
65
+ "- update_kyc_*: KYC updates (activity, purpose, origin of assets, total assets)\n"
66
+ )
67
+
68
+ # --------------------- WRITABLE HF CACHE -----------------------------
69
+
70
  HOME = Path(os.environ.get("HOME", "/home/user"))
71
  CACHE_DIR = HOME / ".cache" / "huggingface"
72
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
73
  os.environ.setdefault("HF_HOME", str(CACHE_DIR))
74
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # faster downloads when supported
75
+
76
+ HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
77
 
78
+ # -------------------- TRANSFORMERS / TORCH ---------------------------
79
 
 
80
  try:
81
  import torch
82
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 
 
 
83
  except Exception as e:
84
  raise RuntimeError(
85
+ "Missing deps. In requirements.txt include: transformers>=4.41, torch, accelerate, huggingface_hub"
 
86
  ) from e
87
 
 
 
 
 
88
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
89
+ DTYPE_FALLBACK = (
90
+ torch.bfloat16 if (DEVICE == "cuda" and torch.cuda.is_bf16_supported()) else
91
+ (torch.float16 if DEVICE == "cuda" else torch.float32)
92
+ )
93
 
94
+ # ZeroGPU presence (optional)
95
  try:
96
  import spaces # noqa: F401
97
  ON_ZERO_GPU = True
98
  except Exception:
99
  ON_ZERO_GPU = False
100
 
101
+ # -------------------------- HELPERS ---------------------------------
 
 
 
 
 
102
 
103
+ def _json_from_text(text: str) -> str:
104
+ s = text.strip()
105
+ if s.startswith("{") and s.endswith("}"):
106
+ return s
107
+ m = re.search(r"\{.*\}", s, re.DOTALL)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return m.group(0) if m else '{"labels": []}'
109
 
110
+ def safe_json(s: str) -> dict:
111
  try:
112
  return json.loads(s)
113
  except Exception:
 
115
 
116
  def build_prompt(system: str, context: str, transcript: str) -> str:
117
  return (
118
+ f"### System\n{system}\n\n"
119
+ f"### Context\n{context}\n\n"
120
+ f"### Transcript\n{transcript}\n\n"
121
+ "### Output\nReturn JSON only."
122
  )
123
 
124
+ # -------------------------- MODEL -----------------------------------
125
+
126
  class HFModel:
127
  def __init__(
128
  self,
129
  repo_id: str,
130
+ revision: Optional[str],
131
+ token: Optional[str],
132
+ load_in_4bit: bool,
133
+ dtype
134
+ ):
 
135
  self.repo_id = repo_id
136
  self.revision = revision or "main"
 
137
  self.token = token
 
138
  self.load_in_4bit = load_in_4bit and (DEVICE == "cuda")
139
+ self.dtype = dtype
140
  self.tokenizer = None
141
  self.model = None
142
 
143
  def load(self):
144
+ qcfg = BitsAndBytesConfig(load_in_4bit=True) if self.load_in_4bit else None
145
+ self.tokenizer = AutoTokenizer.from_pretrained(
146
+ self.repo_id, revision=self.revision, token=self.token,
147
+ cache_dir=str(CACHE_DIR), use_fast=True, trust_remote_code=True
148
+ )
149
+ self.model = AutoModelForCausalLM.from_pretrained(
150
+ self.repo_id, revision=self.revision, token=self.token,
151
+ cache_dir=str(CACHE_DIR), trust_remote_code=True,
152
+ torch_dtype=self.dtype,
153
+ device_map="auto" if DEVICE == "cuda" else None,
154
+ quantization_config=qcfg, low_cpu_mem_usage=True
155
+ )
156
+ if DEVICE == "cpu":
157
+ self.model = self.model.to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  @torch.inference_mode()
160
+ def generate(self, prompt: str, max_new_tokens=256, temperature=0.1) -> str:
161
  tok = self.tokenizer
162
  mdl = self.model
163
  if tok.pad_token is None:
164
  tok.pad_token = tok.eos_token
 
165
  inputs = tok(prompt, return_tensors="pt").to(mdl.device)
166
  out = mdl.generate(
167
  **inputs,
 
172
  pad_token_id=tok.eos_token_id,
173
  eos_token_id=tok.eos_token_id,
174
  )
175
+ decoded = tok.decode(out[0], skip_special_tokens=True)
176
+ gen = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded
177
+ return _json_from_text(gen)
178
 
 
179
  _MODEL_CACHE: Dict[Tuple[str, Optional[str], bool], HFModel] = {}
180
 
181
  def get_model(repo_id: str, revision: Optional[str], load_in_4bit: bool) -> HFModel:
182
  key = (repo_id, revision, load_in_4bit)
183
  if key in _MODEL_CACHE:
184
  return _MODEL_CACHE[key]
185
+ mdl = HFModel(repo_id, revision, HF_TOKEN, load_in_4bit, DTYPE_FALLBACK)
186
+ try:
187
+ mdl.load()
188
+ except Exception as e:
189
+ raise RuntimeError(
190
+ f"Model load failed for {repo_id}@{revision or 'main'} — "
191
+ "If this is a gated/private model, ensure you accepted its license and set HF_TOKEN."
192
+ ) from e
193
+ _MODEL_CACHE[key] = mdl
194
+ return mdl
195
+
196
+ # ---------------------- INFERENCE ROUTES ----------------------------
197
 
 
198
  def run_single(
199
  model_choice: str,
200
  custom_repo_id: str,
 
207
  add_header: bool,
208
  strip_smalltalk: bool,
209
  load_in_4bit: bool,
210
+ ):
211
  debug = []
212
  t0 = time.perf_counter()
213
 
214
+ repo = (custom_repo_id or model_choice or APERTUS_REPO).strip()
215
  rev = PINNED_REVISIONS.get(repo, None)
216
+ debug.append(f"Repo: {repo} | Rev: {rev or 'main'} | Dev: {DEVICE} | 4bit: {load_in_4bit}")
217
 
218
  if preprocess:
219
  lines = [ln.rstrip() for ln in transcript.splitlines()]
220
  if strip_smalltalk:
221
+ lines = [ln for ln in lines if not re.search(r"\b(thanks?|merci|grazie|danke|bye|tsch(ü|u)ss|ciao|ok(ay)?)\b", ln, re.I)]
222
  transcript = "\n".join(lines[-32768:])
223
  if add_header:
224
  transcript = f"[EMAIL/MESSAGE SIGNAL]\n{transcript}"
 
232
 
233
  try:
234
  model = get_model(repo, rev, load_in_4bit)
235
+ raw = model.generate(prompt)
236
+ data = safe_json(raw)
237
  out_json = json.dumps(data, ensure_ascii=False)
238
+ debug.append(f"Done in {time.perf_counter()-t0:.2f}s")
239
  return repo, (rev or "main"), out_json, "\n".join(debug)
240
  except Exception as e:
241
  debug.append(f"ERROR: {e}")
242
  return repo, (rev or "main"), json.dumps({"labels": []}), "\n".join(debug)
243
 
 
244
  def run_batch(
245
  model_choice: str,
246
  custom_repo_id: str,
247
  system: str,
248
  context: str,
249
+ zip_file,
250
  soft_token_cap: int,
251
  preprocess: bool,
252
  lines_window: int,
253
  add_header: bool,
254
  strip_smalltalk: bool,
255
  load_in_4bit: bool,
256
+ ):
257
  debug = []
258
+ repo = (custom_repo_id or model_choice or APERTUS_REPO).strip()
259
  rev = PINNED_REVISIONS.get(repo, None)
260
 
261
  if not zip_file:
 
264
  try:
265
  z = zipfile.ZipFile(zip_file)
266
  names = [n for n in z.namelist() if n.lower().endswith(".txt")]
267
+ debug.append(f"Files: {len(names)}")
268
  except Exception as e:
269
  return repo, (rev or "main"), "filename,labels\n", f"Bad ZIP: {e}"
270
 
271
+ # Warm model once
272
  try:
273
+ _ = get_model(repo, rev, load_in_4bit)
274
  except Exception as e:
275
  return repo, (rev or "main"), "filename,labels\n", f"Model load error: {e}"
276
 
 
278
  for name in names:
279
  try:
280
  txt = z.read(name).decode("utf-8", errors="replace")
281
+ _, _, j, _ = run_single(
282
  model_choice, custom_repo_id, system, context, txt,
283
+ soft_token_cap, preprocess, lines_window, add_header, strip_smalltalk, load_in_4bit
 
284
  )
285
+ labels = safe_json(j).get("labels", [])
286
  rows.append(f"{name},{json.dumps(labels, ensure_ascii=False)}")
287
  except Exception as e:
288
  rows.append(f"{name},[] # error: {e}")
289
 
290
  return repo, (rev or "main"), "\n".join(rows), "\n".join(debug)
291
 
292
+ # ----------------------------- UI -----------------------------------
293
+
294
+ with gr.Blocks(title="From Talk to Task — Multilingual (EN/FR/DE/IT)") as demo:
295
  gr.Markdown(
296
+ f"""
297
+ # From Talk to Task — Multilingual (EN/FR/DE/IT)
298
 
299
+ **Default model:** `{APERTUS_REPO or 'PLEASE SET APERTUS_REPO'}`
300
+ You can override with a custom repo id below.
301
 
302
+ Pinned revisions: {", ".join([f"{k}@{v or 'main'}" for k, v in PINNED_REVISIONS.items()])}
303
+ """
 
 
304
  )
305
 
306
  with gr.Row():
307
  model_choice = gr.Dropdown(
308
+ [APERTUS_REPO] + OPEN_FALLBACKS, label="Model presets",
309
+ value=APERTUS_REPO if APERTUS_REPO else (OPEN_FALLBACKS[0] if OPEN_FALLBACKS else "")
 
310
  )
311
+ custom_repo = gr.Textbox(
312
  label="Custom model repo id (overrides preset)",
313
+ placeholder="e.g. ApertusAI/swiss-apertus-7b-instruct (requires license + HF_TOKEN if gated)"
314
  )
315
 
316
+ system = gr.Textbox(label="Instructions (System)", value=SYSTEM_INSTRUCTIONS, lines=6)
317
+ context = gr.Textbox(label="Context (User prefix)", value=CONTEXT_GUIDE, lines=6)
318
 
319
  with gr.Row():
320
  soft_cap = gr.Slider(1024, 32768, value=8192, step=1, label="Soft token cap")
 
326
  load_4bit = gr.Checkbox(value=False, label="Load in 4-bit (GPU only)")
327
 
328
  with gr.Tabs():
329
+ with gr.Tab("Single Transcript"):
330
+ transcript = gr.Textbox(label="Paste transcript (EN/FR/DE/IT)", lines=12)
331
  run_btn = gr.Button("Run (Single)", variant="primary")
332
  repo_used = gr.Textbox(label="Repo used", interactive=False)
333
  rev_used = gr.Textbox(label="Revision", interactive=False)
334
  json_out = gr.Code(label="JSON Output", language="json")
335
  debug_out = gr.Textbox(label="Diagnostics", lines=6)
336
 
337
+ def _single(*args):
338
+ return run_single(*args)
 
339
 
340
  run_btn.click(
341
+ _single,
342
  inputs=[
343
+ model_choice, custom_repo, system, context, transcript,
344
  soft_cap, preprocess, lines_window, add_header, strip_smalltalk, load_4bit
345
  ],
346
  outputs=[repo_used, rev_used, json_out, debug_out],
 
351
  run_batch_btn = gr.Button("Run (Batch)", variant="primary")
352
  repo_used_b = gr.Textbox(label="Repo used", interactive=False)
353
  rev_used_b = gr.Textbox(label="Revision", interactive=False)
 
354
  csv_out = gr.Textbox(label="CSV (filename,labels)", lines=12)
355
  debug_out_b = gr.Textbox(label="Diagnostics", lines=6)
356
 
357
+ def _batch(*args):
358
+ return run_batch(*args)
 
359
 
360
  run_batch_btn.click(
361
+ _batch,
362
  inputs=[
363
+ model_choice, custom_repo, system, context, zip_in,
364
  soft_cap, preprocess, lines_window, add_header, strip_smalltalk, load_4bit
365
  ],
366
  outputs=[repo_used_b, rev_used_b, csv_out, debug_out_b],
367
  )
368
 
369
  gr.Markdown(
370
+ f"- **HF_TOKEN:** {'✅ set' if HF_TOKEN else '⚠️ not set (only needed for gated/private)'} \n"
371
+ f"- **Device:** {DEVICE} \n"
372
+ f"- **Cache dir:** `{CACHE_DIR}`"
 
 
373
  )
374
 
375
  if __name__ == "__main__":
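
For quick reference, the JSON-only parsing path that this commit introduces can be exercised on its own. The sketch below adapts the `_json_from_text` and `safe_json` helpers from the new app.py (the fallback return in the `except` branch is assumed, since the diff truncates it), and the sample model output is invented for illustration; real output comes from `HFModel.generate()`.

# Illustrative only: exercises the JSON-only output contract outside the Space.
import json
import re

def _json_from_text(text: str) -> str:
    # Keep the first {...} block if the model wrapped the JSON in prose.
    s = text.strip()
    if s.startswith("{") and s.endswith("}"):
        return s
    m = re.search(r"\{.*\}", s, re.DOTALL)
    return m.group(0) if m else '{"labels": []}'

def safe_json(s: str) -> dict:
    # Assumed fallback: never raise, return an empty label list instead.
    try:
        return json.loads(s)
    except Exception:
        return {"labels": []}

# Hypothetical raw generation with prose around the JSON.
raw = 'Sure! Here is the result:\n{"labels": ["schedule_meeting"]}\nDone.'
print(safe_json(_json_from_text(raw))["labels"])  # -> ['schedule_meeting']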