"""Gradio UI for the Data Curation Workbench MVP.""" from __future__ import annotations import json import logging import os from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple import gradio as gr from huggingface_hub import inspect_job, run_job from huggingface_hub.utils import HfHubHTTPError from utils.config import ConfigError, load_space_config from utils.data import DatasetSpec, load_candidate_catalog from utils.hub import ensure_results_repo, ensure_uploaded_dataset CONFIG_ERROR: str | None try: CONFIG = load_space_config() CONFIG_ERROR = None except ConfigError as err: CONFIG = None CONFIG_ERROR = str(err) CATALOG_PATH = Path("catalog/candidates.json") def load_candidates() -> List[DatasetSpec]: if not CATALOG_PATH.exists(): return [] return load_candidate_catalog(CATALOG_PATH) CANDIDATES = load_candidates() LOGGER = logging.getLogger(__name__) def environment_diagnostics() -> Tuple[Dict[str, Any], Dict[str, Any]]: """Surface blocking configuration issues and advisories in the UI.""" blocking: List[str] = [] warnings: List[str] = [] if CONFIG_ERROR: blocking.append(f"Configuration error: {CONFIG_ERROR}") service_token = os.getenv("SERVICE_HF_TOKEN", "").strip() if not service_token: blocking.append( "`SERVICE_HF_TOKEN` is not set. Uploading datasets and launching jobs " "requires a service token with write access." ) oauth_client = os.getenv("OAUTH_CLIENT_ID", "").strip() oauth_secret = os.getenv("OAUTH_CLIENT_SECRET", "").strip() if oauth_client and not oauth_secret: warnings.append("`OAUTH_CLIENT_SECRET` is missing; Hugging Face login will be disabled.") elif oauth_secret and not oauth_client: warnings.append("`OAUTH_CLIENT_ID` is missing; Hugging Face login will be disabled.") elif not oauth_client and not oauth_secret: warnings.append( "OAuth credentials are not configured. Users will need a service token or " "manual dataset ids to submit experiments." 
        )

    message_lines: List[str] = []
    if blocking:
        message_lines.append("**Blocking issues detected:**")
        message_lines.extend(f"- {issue}" for issue in blocking)
    if warnings:
        message_lines.append("**Warnings:**")
        message_lines.extend(f"- {warn}" for warn in warnings)

    display_message = "\n".join(message_lines)
    banner_update = gr.update(value=display_message, visible=bool(display_message))
    run_button_state = gr.update(interactive=not blocking)
    return banner_update, run_button_state


DEFAULT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SIZES = [5000, 10000, 20000]

TASK_OPTIONS: List[Tuple[str, str]] = [
    ("classification", "classification"),
    ("qa", "qa"),
    ("pretraining", "language model pretraining"),
    ("speech_recognition", "speech recognition"),
]
TASK_LABEL_TO_VALUE: Dict[str, str] = {label: value for value, label in TASK_OPTIONS}
TASK_VALUE_TO_LABEL: Dict[str, str] = {value: label for value, label in TASK_OPTIONS}

TASK_METRIC_CHOICES: Dict[str, List[str]] = {
    "classification": ["loss", "f1", "exact_match"],
    "qa": ["loss", "f1", "exact_match"],
    "pretraining": ["loss", "perplexity"],
    "speech_recognition": ["loss", "Word Error Rate (WER)"],
}
TASK_METRIC_DEFAULT: Dict[str, List[str]] = {
    "classification": ["f1"],
    "qa": ["f1"],
    "pretraining": ["perplexity"],
    "speech_recognition": ["Word Error Rate (WER)"],
}
TASK_MODEL_CHOICES: Dict[str, List[str]] = {
    "classification": [DEFAULT_MODEL],
    "qa": [DEFAULT_MODEL],
    "pretraining": [DEFAULT_MODEL],
    "speech_recognition": [
        "anton-l/emformer-base-librispeech",
        "train from scratch",
    ],
}
TASK_BENCHMARK_CHOICES: Dict[str, List[str]] = {
    "speech_recognition": [
        "sanchit-gandhi/tedlium-data.test",
        "openslr/librispeech_asr.test.clean",
    ]
}

BLOCK_CSS = """
.page-title {
    background: transparent !important;
    padding: 0 !important;
    margin-bottom: calc(1.5 * var(--spacing-lg));
}
.page-description {
    background: transparent !important;
    padding: 0 !important;
    margin-top: calc(-0.5 * var(--spacing-md));
    margin-bottom: calc(1.5 * var(--spacing-xl));
    color: var(--secondary-text-color);
}
.spec-card {
    margin-bottom: calc(2 * var(--spacing-xl));
}
.spec-card .card-surface {
    background: var(--panel-background) !important;
    border-radius: var(--radius-lg);
    box-shadow: var(--shadow-sm);
    padding: var(--spacing-lg) calc(1.5 * var(--spacing-lg));
}
.spec-card .card-surface > :not(:last-child) {
    margin-bottom: var(--spacing-md);
}
.spec-heading, .spec-label {
    background: transparent !important;
    box-shadow: none !important;
    padding: 0 !important;
    margin: 0;
}
.spec-heading h3 {
    margin: 0;
    font-size: 1.05rem;
}
.spec-label {
    font-weight: 600;
    color: var(--body-text-color);
}
.run-button, .run-button button, button.run-button {
    background-color: #f97316 !important;
    border-color: #f97316 !important;
    color: #ffffff !important;
}
.run-button:hover, .run-button:focus,
.run-button button:hover, .run-button button:focus,
button.run-button:hover, button.run-button:focus {
    background-color: #ea580c !important;
    border-color: #ea580c !important;
    color: #ffffff !important;
}
"""


def _task_value_from_label(label: str) -> str:
    try:
        return TASK_LABEL_TO_VALUE[label]
    except KeyError as exc:
        raise ValueError(f"Unsupported task label '{label}'.") from exc


def _task_label_from_value(value: str) -> str:
    try:
        return TASK_VALUE_TO_LABEL[value]
    except KeyError as exc:
        raise ValueError(f"Unsupported task '{value}'.") from exc


def _normalize_task_value(task: str) -> str:
    if task in TASK_VALUE_TO_LABEL:
        return task
    return _task_value_from_label(task)


def _model_choices_for_task(task: str) -> List[str]:
    try:
        choices = TASK_MODEL_CHOICES[task]
    except KeyError as exc:
        raise ValueError(f"Unsupported task '{task}'.") from exc
    if not choices:
        raise ValueError(f"No base models configured for task '{task}'.")
    return choices


def _target_label_for_task(task: str) -> str:
    if task == "speech_recognition":
        return "Target dataset size for full-scale training (hours)"
    return "Target dataset size for full-scale training"


def _coerce_int_list(values: Iterable[Any] | None) -> List[int]:
    if values is None:
        return []
    ints: List[int] = []
    for value in values:
        if value is None:
            raise ValueError("Mixture sizes contain an empty entry.")
        if isinstance(value, str):
            stripped = value.strip()
            if not stripped:
                raise ValueError("Mixture sizes contain an empty entry.")
            value = stripped
        try:
            ints.append(int(value))
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Invalid mixture size '{value}'.") from exc
    return ints


def candidate_choices_for_task(task: str) -> List[str]:
    return [spec.dataset_id for spec in CANDIDATES if spec.task == task]


def metrics_for_task(task: str) -> Tuple[List[str], List[str]]:
    if task not in TASK_METRIC_CHOICES:
        raise ValueError(f"Unsupported task '{task}'.")
    choices = TASK_METRIC_CHOICES[task]
    defaults = TASK_METRIC_DEFAULT.get(task, [choices[0]])
    return choices, defaults


def on_task_change(
    selected_task_label: str,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    task_value = _task_value_from_label(selected_task_label)
    metric_choices, metric_defaults = metrics_for_task(task_value)
    candidate_choices = candidate_choices_for_task(task_value)
    model_choices = _model_choices_for_task(task_value)
    benchmark_choices = TASK_BENCHMARK_CHOICES.get(task_value, [])
    return (
        gr.update(choices=metric_choices, value=metric_defaults),
        gr.update(choices=candidate_choices, value=[]),
        gr.update(choices=model_choices, value=model_choices[0]),
        gr.update(choices=benchmark_choices, value=[]),
        gr.update(label=_target_label_for_task(task_value)),
    )


def submit_experiments(
    d0_files: List[Any],
    d0_id: str,
    task: str,
    model: str,
    metrics: List[str],
    dk_list: List[str],
    sizes: List[Any],
    target_size: float,
    test_files: Optional[List[Any]],
    test_id: str,
    public_benchmarks: Optional[List[str]] = None,
    profile: Optional[gr.OAuthProfile] = None,
    oauth: Optional[gr.OAuthToken] = None,
) -> List[Dict[str, Any]]:
    if CONFIG_ERROR:
        raise RuntimeError(f"Configuration error: {CONFIG_ERROR}")
    assert CONFIG is not None

    task_value = _normalize_task_value(task)
    task_label = _task_label_from_value(task_value)
    selected_public_benchmarks = list(public_benchmarks or [])

    try:
        CONFIG.require_service_token()
    except ConfigError as exc:
        raise RuntimeError(
            "`SERVICE_HF_TOKEN` is required to submit experiments. Configure the secret "
            "in the Space settings before retrying."
        ) from exc

    metric_choices, _ = metrics_for_task(task_value)
    if not metrics:
        raise ValueError("Select at least one metric for the chosen task.")
    invalid_metrics = [metric for metric in metrics if metric not in metric_choices]
    if invalid_metrics:
        invalid = ", ".join(invalid_metrics)
        raise ValueError(f"Unsupported metric(s) for task '{task_label}': {invalid}.")
    selected_metrics = list(metrics)

    selected_sizes = _coerce_int_list(sizes)
    if not selected_sizes:
        raise ValueError("Choose at least one mixture size.")
    if any(size <= 0 for size in selected_sizes):
        raise ValueError("Mixture sizes must be positive integers.")

    if target_size is None or (isinstance(target_size, str) and not target_size.strip()):
        raise ValueError("Provide a target corpus size for scaling-law prediction.")
    try:
        target_size_int = int(target_size)
    except (TypeError, ValueError) as exc:
        raise ValueError("Target size must be an integer value.") from exc
    if target_size_int <= 0:
        raise ValueError("Target size must be a positive integer.")

    if not dk_list:
        raise ValueError("Select at least one candidate dataset to evaluate.")

    user_token = getattr(oauth, "token", None)
    d0_repo = ensure_uploaded_dataset(
        CONFIG,
        upload_files=d0_files or [],
        d0_dataset_id=d0_id.strip() or None,
        user_token=user_token,
    )
    results_repo = ensure_results_repo(CONFIG)

    test_repo: Optional[str] = None
    test_id_clean = test_id.strip()
    if test_files:
        test_repo = ensure_uploaded_dataset(
            CONFIG,
            upload_files=test_files,
            d0_dataset_id=test_id_clean or None,
            user_token=user_token,
        )
    elif test_id_clean:
        test_repo = test_id_clean

    jobs: List[Dict[str, Any]] = []
    env = CONFIG.job_env(user_token=user_token)
    job_token = CONFIG.resolve_job_token(user_token=user_token)
    job_namespace = CONFIG.resolve_job_namespace(job_token)

    for dk in dk_list:
        command = [
            "python",
            "jobs/run_experiment.py",
            "--model",
            model,
            "--task",
            task_value,
            "--d0",
            d0_repo,
            "--dk",
            dk,
            "--metrics",
            *selected_metrics,
            "--sizes",
            *[str(size) for size in selected_sizes],
            "--target_size",
            str(target_size_int),
            "--results_repo",
            results_repo,
        ]
        if test_repo:
            command.extend(["--test_dataset", test_repo])
        try:
            job = run_job(
                image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
                command=command,
                flavor="a10g-small",
                timeout=7200,
                env=env,
                token=job_token,
                namespace=job_namespace,
            )
        except HfHubHTTPError as exc:
            if getattr(exc.response, "status_code", None) == 403:
                raise RuntimeError(
                    "`SERVICE_HF_TOKEN` lacks `jobs:write` or is not allowed to create Jobs in the selected "
                    "namespace. Sign in with Hugging Face or configure a token with Jobs permissions."
                ) from exc
            raise
        jobs.append(
            {
                "id": job.id,
                "dk": dk,
                "url": getattr(job, "url", ""),
                "status": job.status,
                "artifacts": "",
                "benchmarks": selected_public_benchmarks,
            }
        )
    return jobs


def submit_with_feedback(
    current_jobs: List[Dict[str, Any]],
    d0_files: List[Any],
    d0_id: str,
    task: str,
    model: str,
    metrics: List[str],
    dk_list: List[str],
    sizes: List[Any],
    target_size: float,
    test_files: Optional[List[Any]] = None,
    test_id: str = "",
    public_benchmarks: Optional[List[str]] = None,
    profile: Optional[gr.OAuthProfile] = None,
    oauth: Optional[gr.OAuthToken] = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    task_value = _normalize_task_value(task)
    try:
        jobs = submit_experiments(
            d0_files=d0_files,
            d0_id=d0_id,
            task=task_value,
            model=model,
            metrics=metrics,
            dk_list=dk_list,
            sizes=sizes,
            target_size=target_size,
            public_benchmarks=public_benchmarks,
            test_files=test_files,
            test_id=test_id,
            profile=profile,
            oauth=oauth,
        )
    except Exception as exc:  # noqa: BLE001 - surface the exception to the UI
        LOGGER.exception("Experiment submission failed: %s", exc)
        message = f"❌ Experiment submission failed: {exc}"
        banner = gr.update(value=message, visible=True)
        return current_jobs, banner
    banner = gr.update(value="", visible=False)
    return jobs, banner


def poll_jobs(jobs_state: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    if CONFIG_ERROR or not jobs_state:
        return jobs_state
    updated: List[Dict[str, Any]] = []
    results_repo = CONFIG.results_repo if CONFIG else None
    for job in jobs_state:
        try:
            info = inspect_job(job["id"])
            status = getattr(info, "status", job.get("status", "unknown"))
        except Exception:  # best effort poll
            status = job.get("status", "unknown")
        artifacts = job.get("artifacts", "")
        if status == "completed" and not artifacts:
            repo_hint = results_repo or "(results repo)"
            artifacts = f"{repo_hint}/experiments/{job['id']}"
        updated.append({**job, "status": status, "artifacts": artifacts})
    return updated


def render_table(jobs: List[Dict[str, Any]]) -> List[List[str]]:
    rows: List[List[str]] = []
    for job in jobs or []:
        rows.append([
            job.get("id", ""),
            job.get("dk", ""),
            job.get("status", ""),
            job.get("url", ""),
            job.get("artifacts", ""),
        ])
    return rows


def build_interface() -> gr.Blocks:
    with gr.Blocks(
        title="Data Acquisition and Curation Intelligence",
        css=BLOCK_CSS,
    ) as demo:
        login_enabled = bool(
            os.getenv("OAUTH_CLIENT_ID", "").strip()
            and os.getenv("OAUTH_CLIENT_SECRET", "").strip()
        )
        if login_enabled:
            gr.LoginButton()
        else:
            gr.Markdown(
                "⚠️ Hugging Face login is disabled until `OAUTH_CLIENT_ID` and `OAUTH_CLIENT_SECRET` are configured.",
                visible=True,
            )
        status_banner = gr.Markdown("", visible=False)

        initial_task_value = "classification"
        initial_task_label = _task_label_from_value(initial_task_value)
        metric_choices, metric_defaults = metrics_for_task(initial_task_value)
        candidate_choices = candidate_choices_for_task(initial_task_value)
        model_choices = _model_choices_for_task(initial_task_value)
        benchmark_choices = TASK_BENCHMARK_CHOICES.get(initial_task_value, [])

        gr.Markdown(
            "# Data Acquisition and Curation Intelligence",
            elem_classes="page-title",
        )
        gr.Markdown(
            (
                "Estimate the return on external datasets before you spend on full-scale training.\n"
                "Select the dataset(s) you want to try, launch small proxy training runs, and get a performance forecast for full-scale training."
            ),
            elem_classes="page-description",
        )

        with gr.Group(elem_classes="spec-card"):
            with gr.Column(elem_classes="card-surface"):
                gr.HTML(
                    "