Update app.py
app.py
CHANGED
@@ -1,232 +1,149 @@
-"""
-Gradio application for the Swiss {ai} Weeks "From Talk to Task" challenge.
-
-This app provides two modes of operation:
-
-* **Single transcript** – Paste or upload a single conversation transcript
-  and the model will extract actionable tasks according to a
-  predefined list of labels. It outputs a human-readable summary,
-  strict JSON, and diagnostic information (e.g. device, latency).
-
-* **Batch evaluation** – Upload a ZIP archive containing one `.txt`
-  transcript per call and a matching `.json` file with the ground
-  truth labels. The app runs the model on each transcript, compares
-  the predictions against the true labels and computes the official
-  weighted score used by the challenge organisers. It also reports
-  precision, recall and F1, and provides a per-sample results table
-  that can be downloaded as CSV.
-
-The official allowed labels and evaluation function are taken from
-the challenge repository README. False negatives are penalised twice as
-heavily as false positives, so recall is especially important.
-"""
-
-import os
-import io
-import re
-import json
-import time
-import zipfile
-from pathlib import Path
-from typing import List, Dict, Any, Tuple, Optional
-
-import gradio as gr
-import numpy as np
-import pandas as pd
-import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    BitsAndBytesConfig,
-    GenerationConfig,
-)
-
-
-# =============================================================================
-# Configuration and Constants
-# =============================================================================
-
-# Cache directory for HuggingFace models
-SPACE_CACHE = Path.home() / ".cache" / "huggingface"
-SPACE_CACHE.mkdir(parents=True, exist_ok=True)
-
-# Device selection
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Generation parameters tuned for speed/quality
-GEN_CONFIG = GenerationConfig(
-    temperature=0.2,
-    top_p=0.9,
-    do_sample=False,
-    max_new_tokens=256,
-)
-
-# Official allowed task labels
-DEFAULT_ALLOWED_LABELS = [
-    "plan_contact",
-    "schedule_meeting",
-    "update_contact_info_non_postal",
-    "update_contact_info_postal_address",
-    "update_kyc_activity",
-    "update_kyc_origin_of_assets",
-    "update_kyc_purpose_of_businessrelation",
-    "update_kyc_total_assets",
-]
-
-# System and user prompt templates
-SYSTEM_PROMPT = (
-    "You are a precise banking assistant that extracts ACTIONABLE TASKS "
-    "from client–advisor transcripts. Return STRICT JSON with fields: "
-    '{"labels": ["<Label1>", ...], "tasks": [{"label": "<Label1>", "explanation": "<why>", "evidence": "<span>"}]} '
-    "Only use labels from the provided Allowed Labels list; if none apply, return an empty list."
-)
 
-
-```
-{transcript}
-```
-
-Allowed Labels:
 {allowed_labels_list}
 
-
-{
-    "labels": ["LabelA", "LabelB", ...],
-    "tasks": [
-        {{"label": "LabelA", "explanation": "…", "evidence": "…"}},
-        {{"label": "LabelB", "explanation": "…", "evidence": "…"}}
-    ]
-}}
-"""
-
 
-
-
-
 def _now_ms() -> int:
-    """Return the current time in milliseconds."""
     return int(time.time() * 1000)
 
-
-def read_file_to_text(file: Optional[gr.File]) -> str:
-    """
-    Read an uploaded file (txt/md/json) to a string. For JSON files,
-    return the value of the "transcript" field if present, otherwise
-    return the entire JSON as a compact string.
-    """
-    if not file or not file.name:
-        return ""
-    name = file.name.lower()
-    data = file.read()
-    if name.endswith(".json"):
-        try:
-            obj = json.loads(data.decode("utf-8", errors="ignore"))
-            if isinstance(obj, dict) and "transcript" in obj:
-                return str(obj["transcript"])
-            return json.dumps(obj, ensure_ascii=False)
-        except Exception:
-            return data.decode("utf-8", errors="ignore")
-    return data.decode("utf-8", errors="ignore")
-
-
 def normalize_labels(labels: List[str]) -> List[str]:
-    """Deduplicate and strip whitespace from a list of labels."""
     return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
 
-
 def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
-    """Map lowercase labels to their canonical names."""
     return {lab.lower(): lab for lab in allowed}
 
-
 def robust_json_extract(text: str) -> Dict[str, Any]:
-    """
-    Extract the first JSON object from a string. Removes common
-    trailing comma mistakes. Returns an empty prediction if no JSON
-    object is found.
-    """
     if not text:
         return {"labels": [], "tasks": []}
     start, end = text.find("{"), text.rfind("}")
-    candidate = text[start:end
-    candidate = re.sub(r",\s*}\s*", "}", candidate)
-    candidate = re.sub(r",\s*]\s*", "]", candidate)
     try:
         return json.loads(candidate)
     except Exception:
-
-
 
 def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
-    """
-    Restrict predicted labels and tasks to those in the allowed list.
-    Case-insensitive matching is performed and the canonical form is
-    returned. Duplicates are removed.
-    """
     out = {"labels": [], "tasks": []}
     allowed_map = canonicalize_map(allowed)
-    #
-    filt_labels
     for l in pred.get("labels", []) or []:
-
-            continue
-        k = l.strip().lower()
         if k in allowed_map:
             filt_labels.append(allowed_map[k])
     filt_labels = normalize_labels(filt_labels)
-    #
     filt_tasks = []
     for t in pred.get("tasks", []) or []:
         if not isinstance(t, dict):
             continue
-
-        k = str(lbl).strip().lower()
         if k in allowed_map:
             new_t = dict(t)
             new_t["label"] = allowed_map[k]
             filt_tasks.append(new_t)
-
-    from_tasks = [tt["label"] for tt in filt_tasks if isinstance(tt.get("label"), str)]
-    merged = normalize_labels(list(set(filt_labels) | set(from_tasks)))
     out["labels"] = merged
     out["tasks"] = filt_tasks
     return out
 
-def
-
-    Keep only the last `max_tokens` tokens of a string according to
-    the provided tokenizer. This is useful for long transcripts to
-    reduce inference time.
-    """
-    if max_tokens <= 0:
-        return text
-    tok = tokenizer(text, add_special_tokens=False)["input_ids"]
-    if len(tok) <= max_tokens:
         return text
-
-
 
 class ModelWrapper:
-    """
-    Wraps a HuggingFace model and tokenizer, with optional 4-bit
-    quantisation. Instances are cached per model and quantisation
-    setting.
-    """
     def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
         self.repo_id = repo_id
         self.hf_token = hf_token
         self.load_in_4bit = load_in_4bit
-        self.tokenizer
-        self.model
 
     def load(self):
-        # 4-bit quantisation config
         qcfg = None
         if self.load_in_4bit and DEVICE == "cuda":
             qcfg = BitsAndBytesConfig(

@@ -235,527 +152,433 @@ class ModelWrapper:
                 bnb_4bit_compute_dtype=torch.float16,
                 bnb_4bit_use_double_quant=True,
             )
-
-
-
-            token=self.hf_token,
-            cache_dir=str(SPACE_CACHE),
-            trust_remote_code=True,
-            use_fast=True,
-        )
-        if
-
-
-
-            self.repo_id,
-            token=self.hf_token,
-            cache_dir=str(SPACE_CACHE),
             trust_remote_code=True,
             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
             device_map="auto" if DEVICE == "cuda" else None,
-            low_cpu_mem_usage=True,
-            quantization_config=qcfg,
             attn_implementation="sdpa",
         )
     @torch.inference_mode()
     def generate(self, system_prompt: str, user_prompt: str) -> str:
-        """
-        Generate text from system and user prompts. Chat templates are
-        used if defined on the tokenizer.
-        """
-        assert self.tokenizer is not None and self.model is not None
-        # Chat template support
         if hasattr(self.tokenizer, "apply_chat_template"):
-
-            input_ids = self.tokenizer.apply_chat_template(
-                messages, add_generation_prompt=True, return_tensors="pt"
             ).to(self.model.device)
         else:
-            text = f"<s>[SYSTEM]{system_prompt}[/SYSTEM][USER]{user_prompt}[/USER]"
-
         with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
             out_ids = self.model.generate(
-                **
                 generation_config=GEN_CONFIG,
                 eos_token_id=self.tokenizer.eos_token_id,
                 pad_token_id=self.tokenizer.pad_token_id,
             )
         return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
 
-
-# Model cache keyed by (repo_id, quantisation)
 _MODEL_CACHE: Dict[str, ModelWrapper] = {}
-
-
 def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
-    """Retrieve or load a ModelWrapper from the cache."""
     key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
     if key not in _MODEL_CACHE:
-
-        _MODEL_CACHE[key] =
     return _MODEL_CACHE[key]
 
-
-# =============================================================================
-
 def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
-
-    Official weighted score for the challenge. False negatives
-    incur double the penalty of false positives. Returns a score
-    between 0.0 and 1.0, where 1.0 is perfect.
-    """
-    ALLOWED_LABELS = DEFAULT_ALLOWED_LABELS
     LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
     FN_PENALTY = 2.0
     FP_PENALTY = 1.0
     if len(y_true) != len(y_pred):
-        raise ValueError(f"y_true and y_pred
     n_samples = len(y_true)
-    n_labels = len(
     y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
     y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
-    for i, labels in enumerate(y_true):
-        for l in labels:
-            if l not in LABEL_TO_IDX:
-                raise ValueError(f"Invalid true label '{l}'")
-            y_true_binary[i, LABEL_TO_IDX[l]] = 1
-    for i, labels in enumerate(y_pred):
-        for l in labels:
-            if l not in LABEL_TO_IDX:
-                raise ValueError(f"Invalid predicted label '{l}'")
-            y_pred_binary[i, LABEL_TO_IDX[l]] = 1
-    false_negatives = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
-    false_positives = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
-    weighted_errors = FN_PENALTY * false_negatives + FP_PENALTY * false_positives
-    max_errors_per_sample = FN_PENALTY * np.sum(y_true_binary, axis=1) + FP_PENALTY * (
-        n_labels - np.sum(y_true_binary, axis=1)
-    )
-    per_sample_scores = np.where(
-        max_errors_per_sample > 0,
-        1.0 - (weighted_errors / max_errors_per_sample),
-        1.0,
-    )
-    final_score = float(np.mean(per_sample_scores))
-    return max(0.0, min(1.0, final_score))
-
-# =============================================================================
-# Prediction Utilities
-# =============================================================================
-
-def predict_labels_for_text(
-    model: ModelWrapper,
-    transcript: str,
-    allowed: List[str],
-    max_tokens: int,
-) -> List[str]:
-    """
-    Predict labels for a transcript string using the given model.
-    The transcript is truncated to the last `max_tokens` tokens to
-    reduce inference time. Only labels in `allowed` are returned.
-    """
-    # Truncate transcript
-    truncated = truncate_tokens(model.tokenizer, transcript, max_tokens)
-    allowed_list_str = "\n".join(f"- {lab}" for lab in allowed)
-    user_prompt = USER_PROMPT_TEMPLATE.format(
-        transcript=truncated,
-        allowed_labels_list=allowed_list_str,
-    )
-    raw_out = model.generate(SYSTEM_PROMPT, user_prompt)
-    parsed = robust_json_extract(raw_out)
-    filtered = restrict_to_allowed(parsed, allowed)
-    return filtered.get("labels", []) or []
 
-
-
 def run_single(
     transcript_text: str,
-    transcript_file:
     allowed_labels_text: str,
     model_repo: str,
     use_4bit: bool,
     max_input_tokens: int,
     hf_token: str,
 ) -> Tuple[str, str, str, str]:
-
-    Process a single transcript and return (summary, json_output,
-    diagnostics, raw_model_output). The summary is human-readable,
-    json_output is the strict JSON string, diagnostics contains
-    performance information, and raw_model_output is the unfiltered
-    model response for debugging.
-    """
     t0 = _now_ms()
-    raw_text =
     if not raw_text:
-        return (
-
-    # Determine allowed labels
     user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
-    allowed = normalize_labels(user_allowed or
-
     try:
-        model = get_model(model_repo, hf_token.strip() or None, use_4bit)
     except Exception as e:
-        return (
-
-    truncated = truncate_tokens(model.tokenizer, raw_text, max_input_tokens)
-    allowed_list_str = "\n".join(f"- {lab}" for lab in allowed)
     user_prompt = USER_PROMPT_TEMPLATE.format(
-        transcript=
         allowed_labels_list=allowed_list_str,
     )
     # Generate
     try:
-
     except Exception as e:
-        return (
-            "",
-            "",
-            f"Generation error: {e}",
-            json.dumps({"labels": [], "tasks": []}, indent=2),
-        )
     t2 = _now_ms()
-
     filtered = restrict_to_allowed(parsed, allowed)
-
     labs = filtered.get("labels", [])
     tasks = filtered.get("tasks", [])
-
-    if labs:
-        summ_lines.append("Detected labels:\n - " + "\n - ".join(labs))
-    else:
-        summ_lines.append("Detected labels: (none)")
     if tasks:
-
-            ev = t.get("evidence", "")
-            trimmed = ev[:140] + ("…" if len(ev) > 140 else "")
-            summ_lines.append(f"• [{lab}] {expl} | evidence: {trimmed}")
     else:
-
-
-# =============================================================================
-# Batch Evaluation Handler
-# =============================================================================
 
 def run_batch(
-    zip_file:
-
     model_repo: str,
     use_4bit: bool,
     max_input_tokens: int,
     hf_token: str,
-
 ) -> Tuple[str, str, str, pd.DataFrame, str]:
-    ""
-    if
-
     try:
-        model = get_model(model_repo, hf_token.strip() or None, use_4bit)
     except Exception as e:
         return (f"Model load failed: {e}", "", "", pd.DataFrame(), "")
-
-    y_pred_list: List[List[str]] = []
-    result_rows: List[Dict[str, Any]] = []
-    total_tp = total_fp = total_fn = 0
-    # Iterate through samples
-    for stem, txt_path, truth_path in paired:
-        # Read transcript
-        try:
-            with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
-                transcript = f.read().strip()
-        except Exception:
-            transcript = ""
-        # Read ground truth labels
-        true_labels: List[str] = []
-        if truth_path and truth_path.is_file():
-            try:
-                with open(truth_path, "r", encoding="utf-8", errors="ignore") as f:
-                    obj = json.load(f)
-                if isinstance(obj, dict) and "labels" in obj:
-                    true_labels = [str(l).strip() for l in obj["labels"] if isinstance(l, str)]
-                elif isinstance(obj, list):
-                    true_labels = [str(l).strip() for l in obj if isinstance(l, str)]
-            except Exception:
-                true_labels = []
-        # Predict labels
-        pred_labels: List[str] = []
-        if transcript:
             try:
-
             except Exception:
-
-    #
-    ""
-
-    )
-
label="Model Repository",
|
| 645 |
-
choices=[
|
| 646 |
-
"swiss-ai/Apertus-8B-Instruct-2509",
|
| 647 |
-
"meta-llama/Meta-Llama-3-8B-Instruct",
|
| 648 |
-
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 649 |
-
],
|
| 650 |
-
value="swiss-ai/Apertus-8B-Instruct-2509",
|
| 651 |
-
)
|
| 652 |
-
use_4bit = gr.Checkbox(
|
| 653 |
-
label="Use 4-bit quantisation (GPU only)", value=True
|
| 654 |
-
)
|
| 655 |
-
max_input_tokens = gr.Slider(
|
| 656 |
-
label="Max input tokens (truncate from end)",
|
| 657 |
-
minimum=1024,
|
| 658 |
-
maximum=8192,
|
| 659 |
-
step=512,
|
| 660 |
-
value=4096,
|
| 661 |
-
)
|
| 662 |
-
hf_token = gr.Textbox(
|
| 663 |
-
label="HF_TOKEN (for gated/private models)",
|
| 664 |
-
type="password",
|
| 665 |
-
value=os.environ.get("HF_TOKEN", ""),
|
| 666 |
-
)
|
| 667 |
-
single_button = gr.Button("Run Extraction", variant="primary")
|
| 668 |
-
with gr.Row():
|
| 669 |
-
summary = gr.Textbox(label="Summary", lines=12)
|
| 670 |
-
json_out = gr.Code(label="Strict JSON Output", language="json")
|
| 671 |
-
with gr.Row():
|
| 672 |
-
diag = gr.Textbox(label="Diagnostics", lines=6)
|
| 673 |
-
raw_out = gr.Textbox(label="Raw Model Output", lines=6)
|
| 674 |
-
# Hook up single button
|
| 675 |
-
single_button.click(
|
| 676 |
-
fn=run_single,
|
| 677 |
-
inputs=[
|
| 678 |
-
transcript_text,
|
| 679 |
-
transcript_file,
|
| 680 |
-
allowed_labels_text,
|
| 681 |
-
model_repo,
|
| 682 |
-
use_4bit,
|
| 683 |
-
max_input_tokens,
|
| 684 |
-
hf_token,
|
| 685 |
-
],
|
| 686 |
-
outputs=[summary, json_out, diag, raw_out],
|
| 687 |
-
)
|
| 688 |
-
with gr.Tab("Batch Evaluation"):
|
| 689 |
-
with gr.Row():
|
| 690 |
-
with gr.Column(scale=3):
|
| 691 |
-
zip_input = gr.File(
|
| 692 |
-
label="ZIP of transcripts and labels", file_types=[".zip"], type="filepath"
|
| 693 |
-
)
|
| 694 |
-
batch_allowed_labels = gr.Textbox(
|
| 695 |
-
label="Allowed Labels (one per line; leave blank for defaults)",
|
| 696 |
-
lines=8,
|
| 697 |
-
)
|
| 698 |
-
max_files_slider = gr.Slider(
|
| 699 |
-
label="Max files to process (0 = no limit)",
|
| 700 |
-
minimum=0,
|
| 701 |
-
maximum=1000,
|
| 702 |
-
step=1,
|
| 703 |
-
value=0,
|
| 704 |
-
)
|
| 705 |
-
with gr.Column(scale=2):
|
| 706 |
-
batch_model_repo = gr.Dropdown(
|
| 707 |
-
label="Model Repository",
|
| 708 |
-
choices=[
|
| 709 |
-
"swiss-ai/Apertus-8B-Instruct-2509",
|
| 710 |
-
"meta-llama/Meta-Llama-3-8B-Instruct",
|
| 711 |
-
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 712 |
-
],
|
| 713 |
-
value="swiss-ai/Apertus-8B-Instruct-2509",
|
| 714 |
-
)
|
| 715 |
-
batch_use_4bit = gr.Checkbox(
|
| 716 |
-
label="Use 4-bit quantisation (GPU only)", value=True
|
| 717 |
-
)
|
| 718 |
-
batch_max_input_tokens = gr.Slider(
|
| 719 |
-
label="Max input tokens (truncate from end)",
|
| 720 |
-
minimum=1024,
|
| 721 |
-
maximum=8192,
|
| 722 |
-
step=512,
|
| 723 |
-
value=4096,
|
| 724 |
-
)
|
| 725 |
-
batch_hf_token = gr.Textbox(
|
| 726 |
-
label="HF_TOKEN (for gated/private models)",
|
| 727 |
-
type="password",
|
| 728 |
-
value=os.environ.get("HF_TOKEN", ""),
|
| 729 |
-
)
|
| 730 |
-
batch_button = gr.Button("Run Batch Evaluation", variant="primary")
|
| 731 |
-
# Outputs
|
| 732 |
-
batch_score = gr.Textbox(label="Score")
|
| 733 |
-
batch_metrics = gr.Textbox(label="Recall / Precision / F1")
|
| 734 |
-
batch_extra = gr.Textbox(label="Summary", lines=2)
|
| 735 |
-
batch_df = gr.Dataframe(label="Per‑sample results", interactive=True, wrap=True)
|
| 736 |
-
batch_download = gr.File(label="Download results (CSV)")
|
| 737 |
-
# Hook up batch button
|
| 738 |
-
def on_batch(zip_file, allowed_text, repo, use4, max_tok, token, max_f):
|
| 739 |
-
score, metrics, extra, df, csv_path = run_batch(
|
| 740 |
-
zip_file, allowed_text, repo, use4, max_tok, token, int(max_f)
|
| 741 |
)
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
|
| 759 |
 if __name__ == "__main__":
-    demo
-    demo.launch()
 
+Allowed Labels (canonical; use only these):
 {allowed_labels_list}
 
+Context cues (keywords/phrases that often indicate each label):
+{keyword_context}
 
+Instructions:
+1) Identify EVERY concrete task implied by the conversation.
+2) Choose ONE label from Allowed Labels for each task (or none if truly inapplicable).
+3) Return STRICT JSON only in the exact schema described by the system prompt.
+"""
 
+# =========================
+# Utilities
+# =========================
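For reference, a minimal sketch of how this template is rendered downstream; OFFICIAL_LABELS, build_keyword_context and the head of USER_PROMPT_TEMPLATE are defined elsewhere in the new app.py and are not part of this hunk, and the transcript string is made up:

allowed = OFFICIAL_LABELS
user_prompt = USER_PROMPT_TEMPLATE.format(
    transcript="Client: please update my e-mail address to jane@example.org ...",
    allowed_labels_list="\n".join(f"- {l}" for l in allowed),
    keyword_context=build_keyword_context(allowed),
)
# The wrapper is then called with (SYSTEM_PROMPT, user_prompt) and must answer in strict JSON.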
 def _now_ms() -> int:
     return int(time.time() * 1000)
 
 def normalize_labels(labels: List[str]) -> List[str]:
     return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
 
 def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
     return {lab.lower(): lab for lab in allowed}
 
 def robust_json_extract(text: str) -> Dict[str, Any]:
     if not text:
         return {"labels": [], "tasks": []}
     start, end = text.find("{"), text.rfind("}")
+    candidate = text[start:end+1] if (start != -1 and end != -1 and end > start) else text
     try:
         return json.loads(candidate)
     except Exception:
+        candidate = re.sub(r",\s*}", "}", candidate)
+        candidate = re.sub(r",\s*]", "]", candidate)
+        try:
+            return json.loads(candidate)
+        except Exception:
+            return {"labels": [], "tasks": []}
 
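A quick sketch of what the two-stage parse tolerates (the raw string below is an illustrative model reply, not a recorded one):

raw = 'Sure, here is the result: {"labels": ["plan_contact"], "tasks": [],}'
print(robust_json_extract(raw))
# -> {'labels': ['plan_contact'], 'tasks': []}   (trailing comma removed on the retry)
print(robust_json_extract("no json here"))
# -> {'labels': [], 'tasks': []}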
 def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
     out = {"labels": [], "tasks": []}
     allowed_map = canonicalize_map(allowed)
+    # labels
+    filt_labels = []
     for l in pred.get("labels", []) or []:
+        k = str(l).strip().lower()
         if k in allowed_map:
             filt_labels.append(allowed_map[k])
     filt_labels = normalize_labels(filt_labels)
+    # tasks
     filt_tasks = []
     for t in pred.get("tasks", []) or []:
         if not isinstance(t, dict):
             continue
+        k = str(t.get("label", "")).strip().lower()
         if k in allowed_map:
             new_t = dict(t)
             new_t["label"] = allowed_map[k]
             filt_tasks.append(new_t)
+    merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
     out["labels"] = merged
     out["tasks"] = filt_tasks
     return out
 
+# =========================
+# Default pre-processing
+# =========================
+# These are conservative; they remove boilerplate that appears in many files
+# and does not affect tasks. You can toggle this in the UI.
+_DISCLAIMER_PATTERNS = [
+    r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
+    r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
+    r"(?is)^\s*this message \(including any attachments\).+?(?:\n{2,}|$)",
+]
+_FOOTER_PATTERNS = [
+    r"(?is)\n+kind regards[^\n]*\n.*$", r"(?is)\n+best regards[^\n]*\n.*$",
+    r"(?is)\n+sent from my.*$", r"(?is)\n+ubs ag.*$",
+]
+_TIMESTAMP_SPEAKER = [
+    r"\[\d{1,2}:\d{2}(:\d{2})?\]",      # [00:01] or [00:01:02]
+    r"^\s*(advisor|client)\s*:\s*",     # Advisor: / Client:
+    r"^\s*(speaker\s*\d+)\s*:\s*",      # Speaker 1:
+]
 
+def clean_transcript(text: str) -> str:
+    if not text:
         return text
+    s = text
+
+    # Remove common timestamps and speaker prefixes (line-wise)
+    lines = []
+    for ln in s.splitlines():
+        ln2 = ln
+        for pat in _TIMESTAMP_SPEAKER:
+            ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
+        lines.append(ln2)
+    s = "\n".join(lines)
+
+    # Remove top disclaimers
+    for pat in _DISCLAIMER_PATTERNS:
+        s = re.sub(pat, "", s).strip()
+
+    # Remove trailing footers/signatures
+    for pat in _FOOTER_PATTERNS:
+        s = re.sub(pat, "", s)
+
+    # Collapse repeated whitespace
+    s = re.sub(r"[ \t]+", " ", s)
+    s = re.sub(r"\n{3,}", "\n\n", s).strip()
+    return s
+
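A small illustration of the cleaning pass on a made-up transcript snippet:

sample = ("[00:01] Advisor: Good morning.\n"
          "[00:02] Client: Please schedule a meeting next week.\n\n"
          "Kind regards\nSent from my phone")
print(clean_transcript(sample))
# -> "Good morning.\nPlease schedule a meeting next week."
# Timestamps and speaker tags are stripped line by line, then the signature footer is dropped.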
+def read_text_from_file(file: gr.File) -> str:
+    if not file or not file.name:
+        return ""
+    name = file.name.lower()
+    data = file.read()
+    if name.endswith(".json"):
+        try:
+            obj = json.loads(data.decode("utf-8", errors="ignore"))
+            if isinstance(obj, dict) and "transcript" in obj:
+                return str(obj["transcript"])
+            return json.dumps(obj, ensure_ascii=False)
+        except Exception:
+            return data.decode("utf-8", errors="ignore")
+    else:
+        return data.decode("utf-8", errors="ignore")
 
+def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
+    toks = tokenizer(text, add_special_tokens=False)["input_ids"]
+    if len(toks) <= max_tokens:
+        return text
+    return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
 
+# =========================
+# HF model wrapper
+# =========================
 class ModelWrapper:
     def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
         self.repo_id = repo_id
         self.hf_token = hf_token
         self.load_in_4bit = load_in_4bit
+        self.tokenizer = None
+        self.model = None
 
     def load(self):
         qcfg = None
         if self.load_in_4bit and DEVICE == "cuda":
             qcfg = BitsAndBytesConfig(
                 bnb_4bit_compute_dtype=torch.float16,
                 bnb_4bit_use_double_quant=True,
             )
+        tok = AutoTokenizer.from_pretrained(
+            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
+            trust_remote_code=True, use_fast=True,
         )
+        if tok.pad_token is None and tok.eos_token:
+            tok.pad_token = tok.eos_token
+        model = AutoModelForCausalLM.from_pretrained(
+            self.repo_id, token=self.hf_token, cache_dir=str(SPACE_CACHE),
             trust_remote_code=True,
             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
             device_map="auto" if DEVICE == "cuda" else None,
+            low_cpu_mem_usage=True, quantization_config=qcfg,
             attn_implementation="sdpa",
         )
+        self.tokenizer = tok
+        self.model = model
 
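The first arguments of the BitsAndBytesConfig call fall between the two hunks and are not shown in this diff; a typical 4-bit setup that would match the visible lines is sketched below, as an assumption rather than the committed code:

qcfg = BitsAndBytesConfig(
    load_in_4bit=True,             # assumed; not visible in this diff
    bnb_4bit_quant_type="nf4",     # assumed; not visible in this diff
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)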
     @torch.inference_mode()
     def generate(self, system_prompt: str, user_prompt: str) -> str:
         if hasattr(self.tokenizer, "apply_chat_template"):
+            msgs = [{"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt}]
+            inputs = self.tokenizer.apply_chat_template(
+                msgs, add_generation_prompt=True, return_tensors="pt"
             ).to(self.model.device)
         else:
+            text = f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n"
+            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
+
         with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
             out_ids = self.model.generate(
+                **inputs,
                 generation_config=GEN_CONFIG,
                 eos_token_id=self.tokenizer.eos_token_id,
                 pad_token_id=self.tokenizer.pad_token_id,
             )
         return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
 
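A minimal usage sketch for the wrapper, assuming a GPU (an 8B model is slow on CPU) and that SYSTEM_PROMPT and OFFICIAL_LABELS are defined earlier in the file:

wrapper = get_model("swiss-ai/Apertus-8B-Instruct-2509", hf_token=None, load_in_4bit=True)
raw = wrapper.generate(SYSTEM_PROMPT, user_prompt)
pred = restrict_to_allowed(robust_json_extract(raw), OFFICIAL_LABELS)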
 _MODEL_CACHE: Dict[str, ModelWrapper] = {}
 def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> ModelWrapper:
     key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}"
     if key not in _MODEL_CACHE:
+        m = ModelWrapper(repo_id, hf_token, load_in_4bit)
+        m.load()
+        _MODEL_CACHE[key] = m
     return _MODEL_CACHE[key]
 
+# =========================
+# Official evaluation (from README)
+# =========================
 def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
+    ALLOWED_LABELS = OFFICIAL_LABELS
     LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
     FN_PENALTY = 2.0
     FP_PENALTY = 1.0
+
+    def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
+        if not isinstance(sample_labels, list):
+            raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
+        # dedupe
+        seen, uniq = set(), []
+        for label in sample_labels:
+            if not isinstance(label, str):
+                raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
+            if label in seen:
+                raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
+            seen.add(label); uniq.append(label)
+        # validity
+        valid = []
+        for label in uniq:
+            if label not in ALLOWED_LABELS:
+                raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
+            valid.append(label)
+        return valid
+
     if len(y_true) != len(y_pred):
+        raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
+
     n_samples = len(y_true)
+    n_labels = len(OFFICIAL_LABELS)
     y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
     y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
 
+    for i, sample_labels in enumerate(y_true):
+        for label in _process_sample_labels(sample_labels, f"y_true[{i}]"):
+            y_true_binary[i, LABEL_TO_IDX[label]] = 1
+
+    for i, sample_labels in enumerate(y_pred):
+        for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
+            y_pred_binary[i, LABEL_TO_IDX[label]] = 1
+
+    fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1)
+    fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1)
+    weighted = 2.0 * fn + 1.0 * fp
+    max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
+    per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
+    return float(max(0.0, min(1.0, np.mean(per_sample))))
+
+# =========================
+# Inference helpers
+# =========================
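A worked example of the weighted score, assuming OFFICIAL_LABELS carries the same 8 labels as the old DEFAULT_ALLOWED_LABELS: if the ground truth for one call is {schedule_meeting, update_kyc_total_assets} and the prediction is {schedule_meeting, plan_contact}, there is 1 false negative and 1 false positive, so the weighted error is 2*1 + 1*1 = 3, the worst case for that sample is 2*2 + 1*6 = 10, and the sample score is 1 - 3/10 = 0.7:

y_true = [["schedule_meeting", "update_kyc_total_assets"]]
y_pred = [["schedule_meeting", "plan_contact"]]
print(evaluate_predictions(y_true, y_pred))  # 0.7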
+def build_keyword_context(allowed: List[str]) -> str:
+    parts = []
+    for lab in allowed:
+        kws = LABEL_KEYWORDS.get(lab, [])
+        if kws:
+            parts.append(f"- {lab}: " + ", ".join(kws))
+        else:
+            parts.append(f"- {lab}: (no default cues)")
+    return "\n".join(parts)
 
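build_keyword_context emits one bullet per allowed label; LABEL_KEYWORDS is defined elsewhere in the new file and is not part of this hunk, so the cues below are only illustrative:

# e.g. with LABEL_KEYWORDS = {"schedule_meeting": ["meet", "appointment", "calendar"]}
# build_keyword_context(["schedule_meeting", "plan_contact"]) would return:
# - schedule_meeting: meet, appointment, calendar
# - plan_contact: (no default cues)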
 def run_single(
     transcript_text: str,
+    transcript_file: gr.File,
+    use_cleaning: bool,
     allowed_labels_text: str,
     model_repo: str,
     use_4bit: bool,
     max_input_tokens: int,
     hf_token: str,
 ) -> Tuple[str, str, str, str]:
+
     t0 = _now_ms()
+
+    # Get transcript
+    raw_text = read_text_from_file(transcript_file) if transcript_file else (transcript_text or "")
+    raw_text = (raw_text or "").strip()
     if not raw_text:
+        return "", "", "No transcript provided.", json.dumps({"labels": [], "tasks": []}, indent=2)
+
+    # Cleaning
+    text = clean_transcript(raw_text) if use_cleaning else raw_text
+
+    # Allowed labels
     user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
+    allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
+
+    # Model
     try:
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
     except Exception as e:
+        return "", "", f"Model load failed: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
+
+    # Truncate
+    trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
+
+    # Build prompt
+    allowed_list_str = "\n".join(f"- {l}" for l in allowed)
+    keyword_ctx = build_keyword_context(allowed)
     user_prompt = USER_PROMPT_TEMPLATE.format(
+        transcript=trunc,
         allowed_labels_list=allowed_list_str,
+        keyword_context=keyword_ctx,
     )
+
     # Generate
+    t1 = _now_ms()
     try:
+        out = model.generate(SYSTEM_PROMPT, user_prompt)
     except Exception as e:
+        return "", "", f"Generation error: {e}", json.dumps({"labels": [], "tasks": []}, indent=2)
     t2 = _now_ms()
+
+    # Parse + filter
+    parsed = robust_json_extract(out)
     filtered = restrict_to_allowed(parsed, allowed)
+
+    # Diagnostics
+    diag = "\n".join([
+        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
+        f"Model: {model_repo}",
+        f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
+        f"Tokens (input, approx): ≤ {max_input_tokens}",
+        f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
+        f"Allowed labels: {', '.join(allowed)}",
+    ])
+
+    # Summary
     labs = filtered.get("labels", [])
     tasks = filtered.get("tasks", [])
+    summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
     if tasks:
+        summary += "\n\nTasks:\n" + "\n".join(
+            f"• [{t['label']}] {t.get('explanation','')} | ev: {t.get('evidence','')[:140]}{'…' if len(t.get('evidence',''))>140 else ''}"
+            for t in tasks
+        )
     else:
+        summary += "\n\nTasks: (none)"
+
+    return summary, json.dumps(filtered, indent=2, ensure_ascii=False), diag, out.strip()
+
+# =========================
+# Batch mode (ZIP with transcripts + truths)
+# =========================
+def read_zip(fileobj: io.BytesIO, exdir: Path) -> List[Path]:
+    exdir.mkdir(parents=True, exist_ok=True)
+    with zipfile.ZipFile(fileobj) as zf:
+        zf.extractall(exdir)
+    out = []
+    for p in exdir.rglob("*"):
+        if p.is_file():
+            out.append(p)
+    return out
 
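The batch evaluator pairs files by stem, so an uploaded archive is expected to look roughly like this (hypothetical names):

# calls.zip
# ├── call_001.txt    <- transcript
# ├── call_001.json   <- {"labels": ["schedule_meeting"]}
# ├── call_002.txt
# └── call_002.json
# Transcripts without a matching .json are scored against an empty truth set.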
 def run_batch(
+    zip_file: gr.File,
+    use_cleaning: bool,
     model_repo: str,
     use_4bit: bool,
     max_input_tokens: int,
     hf_token: str,
+    limit_files: int,
 ) -> Tuple[str, str, str, pd.DataFrame, str]:
+
+    if not zip_file:
+        return ("No ZIP provided.", "", "", pd.DataFrame(), "")
+
+    work = Path("/tmp/batch")
+    if work.exists():
+        for p in work.rglob("*"):
+            try: p.unlink()
+            except Exception: pass
+        try: work.rmdir()
+        except Exception: pass
+    work.mkdir(parents=True, exist_ok=True)
+
+    # Unzip
+    data = zip_file.read()
+    files = read_zip(io.BytesIO(data), work)
+
+    # Gather pairs by stem
+    txts: Dict[str, Path] = {}
+    gts: Dict[str, Path] = {}
+    for p in files:
+        if p.suffix.lower() == ".txt":
+            txts[p.stem] = p
+        elif p.suffix.lower() == ".json":
+            gts[p.stem] = p
+
+    stems = sorted(txts.keys())
+    if limit_files > 0:
+        stems = stems[:limit_files]
+    if not stems:
+        return ("No .txt transcripts found in ZIP.", "", "", pd.DataFrame(), "")
+
+    # Model
     try:
+        model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
     except Exception as e:
         return (f"Model load failed: {e}", "", "", pd.DataFrame(), "")
+
+    allowed = OFFICIAL_LABELS[:]  # fixed for scoring
+    allowed_list_str = "\n".join(f"- {l}" for l in allowed)
+    keyword_ctx = build_keyword_context(allowed)
+
+    y_true, y_pred = [], []
+    rows = []
+    t_start = _now_ms()
+
+    for stem in stems:
+        raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
+        text = clean_transcript(raw) if use_cleaning else raw
+        trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
+
+        user_prompt = USER_PROMPT_TEMPLATE.format(
+            transcript=trunc,
+            allowed_labels_list=allowed_list_str,
+            keyword_context=keyword_ctx,
+        )
+
+        t0 = _now_ms()
+        out = model.generate(SYSTEM_PROMPT, user_prompt)
+        t1 = _now_ms()
+
+        parsed = robust_json_extract(out)
+        filtered = restrict_to_allowed(parsed, allowed)
+        pred_labels = filtered.get("labels", [])
+        y_pred.append(pred_labels)
+
+        # Ground truth (optional)
+        gt_labels = []
+        if stem in gts:
            try:
+                gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
+                if isinstance(gt_obj, dict) and "labels" in gt_obj and isinstance(gt_obj["labels"], list):
+                    gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
            except Exception:
+                pass
+        y_true.append(gt_labels)
+
+        # FP/FN counts for table
+        gt_set = set(gt_labels)
+        pr_set = set(pred_labels)
+        tp = sorted(gt_set & pr_set)
+        fp = sorted(pr_set - gt_set)
+        fn = sorted(gt_set - pr_set)
+
+        rows.append({
+            "file": stem,
+            "true_labels": ", ".join(gt_labels),
+            "pred_labels": ", ".join(pred_labels),
+            "TP": len(tp), "FP": len(fp), "FN": len(fn),
+            "gen_ms": t1 - t0
+        })
+
+    # Metrics
+    # If there is no ground truth in the ZIP, we still compute a table and skip score.
+    have_truth = any(len(v) > 0 for v in y_true)
+    score = evaluate_predictions(y_true, y_pred) if have_truth else None
+
+    df = pd.DataFrame(rows).sort_values(["FN", "FP", "file"])
+    diag = [
+        f"Processed files: {len(stems)}",
+        f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
+        f"Model: {model_repo}",
+        f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
+        f"Tokens (input, approx): ≤ {max_input_tokens}",
+        f"Batch time: {_now_ms()-t_start} ms",
+    ]
+    if have_truth and score is not None:
+        # Simple derived metrics
+        total_tp = int(df["TP"].sum())
+        total_fp = int(df["FP"].sum())
+        total_fn = int(df["FN"].sum())
+        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 1.0
+        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 1.0
+        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
+        diag += [
+            f"Official weighted score (0–1): {score:.3f}",
+            f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}",
+            f"Total TP={total_tp} FP={total_fp} FN={total_fn}",
+        ]
+    diag_str = "\n".join(diag)
+
+    # CSV preview and data URL
+    csv_buf = io.StringIO()
+    df.to_csv(csv_buf, index=False)
+    csv_data = csv_buf.getvalue()
+
+    return ("Batch done.", diag_str, csv_data, df, csv_data)
+
+# =========================
+# UI
+# =========================
+MODEL_CHOICES = [
+    "swiss-ai/Apertus-8B-Instruct-2509",
+    "meta-llama/Meta-Llama-3-8B-Instruct",
+    "mistralai/Mistral-7B-Instruct-v0.3",
+]
+
+with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
+    gr.Markdown("# Talk2Task — Task Extraction (UBS Challenge)")
+    gr.Markdown(
+        "This tool extracts challenge labels from transcripts. "
+        "Use **Single** for quick tests; use **Batch** to score a ZIP with transcripts + truths. "
+        "_Note: False negatives are penalised twice as much as false positives in the official metric; "
+        "we bias for recall._"
+    )
+
+    with gr.Tab("Single transcript"):
+        with gr.Row():
+            with gr.Column(scale=3):
+                file = gr.File(
+                    label="Drag & drop transcript (.txt / .md / .json)",
+                    file_types=[".txt", ".md", ".json"],
+                    type="filepath",
                 )
+                text = gr.Textbox(label="Or paste transcript", lines=14)
+                use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, footers)", value=True)
+                labels_text = gr.Textbox(
+                    label="Allowed Labels (one per line; leave empty to use official list)",
+                    value="",
+                    lines=8,
+                )
+            with gr.Column(scale=2):
+                repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
+                use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
+                max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
+                hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
+                run_btn = gr.Button("Run Extraction", variant="primary")
+
+        with gr.Row():
+            summary = gr.Textbox(label="Summary", lines=12)
+            json_out = gr.Code(label="Strict JSON Output", language="json")
+        with gr.Row():
+            diag = gr.Textbox(label="Diagnostics", lines=8)
+            raw = gr.Textbox(label="Raw Model Output", lines=8)
+
+        run_btn.click(
+            fn=run_single,
+            inputs=[text, file, use_cleaning, labels_text, repo, use_4bit, max_tokens, hf_token],
+            outputs=[summary, json_out, diag, raw],
+        )
 
+    with gr.Tab("Batch evaluation"):
+        with gr.Row():
+            with gr.Column(scale=3):
+                zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
+                use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
+            with gr.Column(scale=2):
+                repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
+                use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
+                max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=4096)
+                hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
+                limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
+                run_batch_btn = gr.Button("Run Batch", variant="primary")
+
+        with gr.Row():
+            status = gr.Textbox(label="Status", lines=1)
+            diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=10)
+
+        with gr.Row():
+            df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, times)", interactive=False)
+            csv_out = gr.File(label="Download CSV (click to save)", interactive=False)
+
+        def _save_csv(csv_text: str) -> str:
+            if not csv_text:
+                return ""
+            out_path = Path("/tmp/batch_results.csv")
+            out_path.write_text(csv_text, encoding="utf-8")
+            return str(out_path)
+
+        run_batch_btn.click(
+            fn=run_batch,
+            inputs=[zip_in, use_cleaning_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
+            outputs=[status, diag_b, csv_out, df_out, gr.Textbox(visible=False)],
+        )
 
 if __name__ == "__main__":
+    demo.launch()