| """Core helpers for the QuickStart Hugging Face repo assistant.""" | |
| from __future__ import annotations | |
| import html | |
| import inspect | |
| import os | |
| import re | |
| import tempfile | |
| import textwrap | |
| import zipfile | |
| from pathlib import Path | |
| from typing import Any | |
| from huggingface_hub import HfApi | |
| from huggingface_hub.utils import HfHubHTTPError | |

VALID_REPO_TYPES = {"model", "dataset", "space"}
RE_REPO_SEGMENT = re.compile(r"^(?!.*(?:--|\.\.))[A-Za-z0-9][A-Za-z0-9_.-]{0,95}$")

SENSITIVE_FILENAME_PATTERNS = [
    r"(^|/)\.env$",
    r"secrets?",
    # A bare "token" would flag tokenizer.json / special_tokens_map.json on
    # nearly every Transformers repo; the lookahead skips those known-safe names.
    r"token(?!izer|s_map)",
    r"api[_-]?key",
    r"credentials?",
    r"id_rsa",
    r"\.pem$",
    r"\.p12$",
    r"\.kdbx$",
]


def esc(value: Any) -> str:
    """HTML-escape values before injecting them into custom Gradio HTML."""
    return html.escape("" if value is None else str(value), quote=True)


def norm_type(value: str | None) -> str:
    repo_type = (value or "model").strip().lower()
    return repo_type if repo_type in VALID_REPO_TYPES else "model"


def norm_id(value: str | None) -> str:
    return (value or "").strip().strip("/")


def is_valid_repo_id(repo_id: str) -> bool:
    repo_id = (repo_id or "").strip()
    parts = repo_id.split("/")
    if len(parts) not in {1, 2}:
        return False
    return all(
        RE_REPO_SEGMENT.match(part) and not part.startswith(("-", ".")) and not part.endswith(("-", "."))
        for part in parts
    )
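
# Illustrative checks, hand-verified against the rules above (not executed;
# the repo IDs are placeholder strings):
#   >>> is_valid_repo_id("openai-community/gpt2")
#   True
#   >>> is_valid_repo_id("owner/name/extra")  # too many segments
#   False
#   >>> is_valid_repo_id("bad--name")  # consecutive dashes are rejected
#   False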


def human_bytes(num_bytes: int | None) -> str:
    if not isinstance(num_bytes, int) or num_bytes <= 0:
        return "N/A"
    units = ["B", "KB", "MB", "GB", "TB"]
    value = float(num_bytes)
    unit_index = 0
    while value >= 1024 and unit_index < len(units) - 1:
        value /= 1024
        unit_index += 1
    return f"{value:.2f} {units[unit_index]}"


def safe_str(value: Any, max_chars: int = 500) -> str:
    text = "" if value is None else str(value)
    text = re.sub(r"\s+", " ", text).strip()
    if len(text) > max_chars:
        return text[: max_chars - 3] + "..."
    return text


def py_literal(value: Any) -> str:
    """Return a safe Python string literal for generated snippets."""
    return repr("" if value is None else str(value))


def parse_hf_input(user_input: str) -> tuple[str, str]:
    """Parse a Hugging Face URL, typed repo path, or plain owner/repo ID."""
    value = (user_input or "").strip()
    if not value:
        return "model", ""
    if "huggingface.co" in value or "hf.co" in value:
        scoped_match = re.search(r"(?:huggingface\.co|hf\.co)/(datasets|spaces)/([^?#]+)", value)
        if scoped_match:
            repo_type = "dataset" if scoped_match.group(1) == "datasets" else "space"
            repo_id = _strip_hf_file_path(scoped_match.group(2))
            return repo_type, repo_id
        model_match = re.search(r"(?:huggingface\.co|hf\.co)/([^?#]+)", value)
        if model_match:
            repo_id = _strip_hf_file_path(model_match.group(1))
            return "model", repo_id
    if value.startswith("datasets/"):
        return "dataset", value.replace("datasets/", "", 1).strip("/")
    if value.startswith("spaces/"):
        return "space", value.replace("spaces/", "", 1).strip("/")
    return "model", value.strip("/")


def _strip_hf_file_path(path: str) -> str:
    path = (path or "").strip("/")
    path = re.split(r"/(tree|blob|resolve|raw|viewer|discussions)/", path)[0].strip("/")
    return path
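
# Illustrative parses, hand-verified against the branches above (not executed;
# the repo IDs are placeholders, not real repos):
#   >>> parse_hf_input("https://huggingface.co/datasets/owner/data/tree/main")
#   ('dataset', 'owner/data')
#   >>> parse_hf_input("https://hf.co/spaces/owner/demo")
#   ('space', 'owner/demo')
#   >>> parse_hf_input("gpt2")
#   ('model', 'gpt2')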


def hf_url(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    if repo_type == "dataset":
        return f"https://huggingface.co/datasets/{repo_id}"
    if repo_type == "space":
        return f"https://huggingface.co/spaces/{repo_id}"
    return f"https://huggingface.co/{repo_id}"


def safe_hf_error(error: HfHubHTTPError) -> str:
    status = getattr(getattr(error, "response", None), "status_code", "N/A")
    message = getattr(error, "server_message", None) or str(error)
    return f"Hugging Face Error: {status} - {safe_str(message, 500)}"


def call_with_supported_kwargs(fn: Any, *args: Any, **kwargs: Any) -> Any:
    """Call an SDK function with only the kwargs its signature supports, without swallowing API errors."""
    try:
        signature = inspect.signature(fn)
    except (TypeError, ValueError):
        return fn(*args, **kwargs)
    parameters = signature.parameters
    # If fn accepts **kwargs itself, every keyword is supported; don't filter.
    if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters.values()):
        return fn(*args, **kwargs)
    supported_kwargs = {key: value for key, value in kwargs.items() if key in parameters}
    return fn(*args, **supported_kwargs)
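
# Illustrative call, hand-verified (not executed): unknown keywords are dropped
# for callables with a fixed signature, so newer kwargs degrade gracefully on
# older SDKs. `legacy` is a made-up stand-in, not a huggingface_hub function:
#   >>> def legacy(repo_id, token=None): return (repo_id, token)
#   >>> call_with_supported_kwargs(legacy, "gpt2", token=None, files_metadata=True)
#   ('gpt2', None)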


def extract_file_entries(info_obj: Any) -> list[dict[str, Any]]:
    entries: list[dict[str, Any]] = []
    siblings = getattr(info_obj, "siblings", None) or []
    for sibling in siblings:
        path = getattr(sibling, "rfilename", None) or getattr(sibling, "path", None)
        if not path:
            continue
        size = getattr(sibling, "size", None)
        if size is None:
            lfs = getattr(sibling, "lfs", None)
            size = getattr(lfs, "size", None) if lfs is not None else None
        entries.append({"path": str(path), "size": int(size) if isinstance(size, int) else None})
    return entries


def files_risk_report(files: list[dict[str, Any]]) -> dict[str, Any]:
    paths = [item.get("path", "") for item in files if item.get("path")]
    total_known = sum(int(item["size"]) for item in files if isinstance(item.get("size"), int))
    lower_paths = [path.lower() for path in paths]
    suspicious_names = [
        path
        for path in paths
        if any(re.search(pattern, path.lower()) for pattern in SENSITIVE_FILENAME_PATTERNS)
    ]
    return {
        "files_count": len(paths),
        "total_size_known": total_known if total_known > 0 else None,
        "has_gguf": any(path.endswith(".gguf") for path in lower_paths),
        "has_onnx": any(path.endswith(".onnx") for path in lower_paths),
        "has_safetensors": any(path.endswith(".safetensors") for path in lower_paths),
        "has_bin": any(path.endswith(".bin") for path in lower_paths),
        "suspicious_names": suspicious_names[:30],
    }
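
# Illustrative report, hand-verified against the patterns above (not executed):
#   >>> report = files_risk_report([
#   ...     {"path": "model.safetensors", "size": 100},
#   ...     {"path": ".env", "size": 10},
#   ... ])
#   >>> report["has_safetensors"], report["suspicious_names"]
#   (True, ['.env'])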


def warnings_from_meta(meta: dict[str, Any]) -> list[str]:
    warnings: list[str] = []
    risk = meta.get("_risk", {}) or {}
    if meta.get("Gated") == "Yes" or meta.get("Private") == "Yes":
        warnings.append("Repo may require HF_TOKEN because it is private or gated.")
    total_size = risk.get("total_size_known")
    if isinstance(total_size, int) and total_size > 8 * 1024**3:
        warnings.append("Large repo size detected (>8GB). Prefer selective download when possible.")
    if risk.get("has_gguf"):
        warnings.append(
            "GGUF detected. Use a llama.cpp / llama-cpp-python flow instead of generic Transformers."
        )
    if risk.get("suspicious_names"):
        warnings.append(
            "Potentially sensitive filenames detected. This is filename-based only; review before use."
        )
    if meta.get("Pipeline") == "text-generation":
        warnings.append("Text-generation models can be slow without adequate GPU/VRAM.")
    return warnings
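
# Illustrative warnings, hand-verified against the branches above (not executed):
#   >>> warnings_from_meta({"Gated": "Yes", "_risk": {"has_gguf": True}})
#   ['Repo may require HF_TOKEN because it is private or gated.',
#    'GGUF detected. Use a llama.cpp / llama-cpp-python flow instead of generic Transformers.']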


def to_files_table(files: list[dict[str, Any]], limit: int = 250) -> list[list[Any]]:
    return [
        [item.get("path", ""), human_bytes(item.get("size")) if isinstance(item.get("size"), int) else "N/A"]
        for item in (files or [])[:limit]
    ]


def filter_files(files: list[dict[str, Any]], query: str, limit: int = 250) -> list[list[Any]]:
    query = (query or "").strip().lower()
    if not query:
        return to_files_table(files, limit=limit)
    rows: list[list[Any]] = []
    for item in files or []:
        path = item.get("path") or ""
        if query in path.lower():
            size = human_bytes(item.get("size")) if isinstance(item.get("size"), int) else "N/A"
            rows.append([path, size])
            if len(rows) >= limit:
                break
    return rows


def first_file_with_ext(files: list[dict[str, Any]], extension: str) -> str | None:
    extension = (extension or "").lower()
    for item in files or []:
        path = item.get("path") or ""
        if path.lower().endswith(extension):
            return path
    return None


def compute_requirements(repo_type: str, meta: dict[str, Any]) -> list[str]:
    repo_type = norm_type(repo_type)
    pipeline_tag = (meta or {}).get("_pipeline_tag", "N/A")
    sdk = (meta or {}).get("_sdk", "N/A")
    # Guard against "_risk" being present but None, as generate_quickstart does.
    has_gguf = bool(((meta or {}).get("_risk") or {}).get("has_gguf") or (meta or {}).get("_has_gguf"))
    if repo_type == "dataset":
        return ["datasets", "huggingface_hub"]
    if repo_type == "space":
        if sdk == "streamlit":
            return ["streamlit", "huggingface_hub", "requests"]
        if sdk == "gradio":
            return ["gradio", "huggingface_hub", "requests"]
        return ["huggingface_hub", "requests"]
    if has_gguf:
        return ["huggingface_hub", "llama-cpp-python"]
    if pipeline_tag == "text-generation":
        return ["transformers", "huggingface_hub", "torch", "accelerate"]
    if pipeline_tag in {"image-classification", "image-to-text", "image-segmentation", "object-detection"}:
        return ["transformers", "huggingface_hub", "torch", "pillow", "requests"]
    return ["transformers", "huggingface_hub", "torch"]


def generate_install(repo_type: str, meta: dict[str, Any]) -> str:
    return "python -m pip install " + " ".join(compute_requirements(repo_type, meta))


def generate_quickstart(repo_type: str, repo_id: str, meta: dict[str, Any]) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    pipeline_tag = (meta or {}).get("_pipeline_tag", "N/A")
    sdk = (meta or {}).get("_sdk", "N/A")
    risk = (meta or {}).get("_risk", {}) or {}
    has_gguf = bool(risk.get("has_gguf") or (meta or {}).get("_has_gguf"))
    files = (meta or {}).get("_files", []) or []
    repo_id_literal = py_literal(repo_id)
    if repo_type == "dataset":
        return textwrap.dedent(
            f"""
            from datasets import load_dataset
            ds = load_dataset({repo_id_literal})
            print(ds)
            """
        ).strip()
    if repo_type == "space":
        repo_dir_literal = py_literal(repo_id.split("/")[-1])
        space_url_literal = py_literal(hf_url("space", repo_id))
        if sdk == "streamlit":
            return textwrap.dedent(
                f"""
                import os
                import subprocess
                subprocess.check_call(["git", "clone", {space_url_literal}])
                os.chdir({repo_dir_literal})
                subprocess.check_call(["python", "-m", "pip", "install", "-r", "requirements.txt"])
                subprocess.check_call(["streamlit", "run", "app.py"])
                """
            ).strip()
        return textwrap.dedent(
            f"""
            import os
            import subprocess
            subprocess.check_call(["git", "clone", {space_url_literal}])
            os.chdir({repo_dir_literal})
            subprocess.check_call(["python", "-m", "pip", "install", "-r", "requirements.txt"])
            subprocess.check_call(["python", "app.py"])
            """
        ).strip()
    if has_gguf:
        gguf_name = first_file_with_ext(files, ".gguf") or "MODEL.gguf"
        gguf_name_literal = py_literal(gguf_name)
        return textwrap.dedent(
            f"""
            from huggingface_hub import hf_hub_download
            from llama_cpp import Llama
            gguf_path = hf_hub_download(repo_id={repo_id_literal}, filename={gguf_name_literal})
            llm = Llama(model_path=gguf_path, n_ctx=4096)
            out = llm("Q: Hello!\\nA:", max_tokens=128)
            print(out["choices"][0]["text"])
            """
        ).strip()
    if pipeline_tag == "text-generation":
        return textwrap.dedent(
            f"""
            from transformers import pipeline
            pipe = pipeline(
                "text-generation",
                model={repo_id_literal},
                device_map="auto",
            )
            out = pipe("Hello, Hugging Face!", max_new_tokens=64)
            print(out[0]["generated_text"])
            """
        ).strip()
    if pipeline_tag == "text-classification":
        return textwrap.dedent(
            f"""
            from transformers import pipeline
            clf = pipeline("text-classification", model={repo_id_literal})
            print(clf("I love this project."))
            """
        ).strip()
    if pipeline_tag == "image-classification":
        return textwrap.dedent(
            f"""
            from io import BytesIO
            import requests
            from PIL import Image
            from transformers import pipeline
            image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
            image = Image.open(BytesIO(requests.get(image_url, timeout=20).content))
            pipe = pipeline("image-classification", model={repo_id_literal})
            print(pipe(image))
            """
        ).strip()
    return textwrap.dedent(
        f"""
        from transformers import AutoModel, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained({repo_id_literal})
        model = AutoModel.from_pretrained({repo_id_literal})
        print(type(tokenizer))
        print(type(model))
        """
    ).strip()
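
# Example of the generated snippet for a hypothetical GGUF repo (has_gguf=True,
# with a files entry ending in ".gguf"; names are placeholders):
#
#   from huggingface_hub import hf_hub_download
#   from llama_cpp import Llama
#   gguf_path = hf_hub_download(repo_id='owner/model-gguf', filename='model.q4_0.gguf')
#   llm = Llama(model_path=gguf_path, n_ctx=4096)
#   out = llm("Q: Hello!\nA:", max_tokens=128)
#   print(out["choices"][0]["text"])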


def generate_snapshot_download(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    local_dir = f"./{repo_id.split('/')[-1]}"
    lines = [
        "from huggingface_hub import snapshot_download",
        "",
        "path = snapshot_download(",
        f"    repo_id={py_literal(repo_id)},",
    ]
    if repo_type != "model":
        lines.append(f"    repo_type={py_literal(repo_type)},")
    lines.extend(
        [
            f"    local_dir={py_literal(local_dir)},",
            ")",
            'print(f"Downloaded to: {path}")',
        ]
    )
    return "\n".join(lines)


def generate_cli_download(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    return f'hf download {repo_id} --repo-type {repo_type} --local-dir "./downloaded_repo"'


def generate_badge(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    url = hf_url(repo_type, repo_id)
    # Shields.io treats "-" and "_" as field separators inside the badge path,
    # so they must be doubled before the slash is percent-encoded.
    encoded = repo_id.replace("-", "--").replace("_", "__").replace("/", "%2F")
    return f"[]({url})"
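
# Example output for a hypothetical repo_id="owner/some-model" (note the
# doubled dash required by Shields.io):
#   [](https://huggingface.co/owner/some-model)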


def token_allowed_for_repo(repo_id: str) -> bool:
    """Return whether the configured server token may be used for this repo.

    Server-token mode is intentionally fail-closed: enabling ALLOW_SERVER_TOKEN
    is not enough on its own. TOKEN_ALLOWED_OWNERS must also scope the token to
    trusted Hugging Face owners.
    """
    owners = os.getenv("TOKEN_ALLOWED_OWNERS", "").strip()
    if not owners:
        return False
    allowed_owners = {owner.strip().lower() for owner in owners.split(",") if owner.strip()}
    repo_id = norm_id(repo_id)
    owner = repo_id.split("/")[0].lower() if "/" in repo_id else ""
    return bool(owner) and owner in allowed_owners


def get_effective_token(repo_id: str) -> str | None:
    if os.getenv("ALLOW_SERVER_TOKEN", "").strip() != "1":
        return None
    token = (os.getenv("HF_TOKEN") or "").strip()
    if not token:
        return None
    return token if token_allowed_for_repo(repo_id) else None
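
# Illustrative deployment configuration (all values are placeholders):
#   ALLOW_SERVER_TOKEN=1
#   HF_TOKEN=hf_xxx                      # never commit a real token
#   TOKEN_ALLOWED_OWNERS=my-org,my-user
# With that in place (hand-checked against the logic above, not executed):
#   >>> get_effective_token("my-org/private-model")  # owner is allow-listed
#   'hf_xxx'
#   >>> get_effective_token("other-org/model") is None  # fail-closed
#   True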


def fetch_repo_info(
    repo_type: str, repo_id: str, token: str | None
) -> tuple[bool, dict[str, Any] | None, str | None]:
    api = HfApi()
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    token = (token or "").strip() or None
    if not repo_id:
        return False, None, "Empty Repo ID."
    if not is_valid_repo_id(repo_id):
        return False, None, "Invalid Repo ID. Expected: repo-name or owner/name"
    try:
        if repo_type == "dataset":
            info = call_with_supported_kwargs(api.dataset_info, repo_id, token=token, files_metadata=True)
        elif repo_type == "space":
            info = call_with_supported_kwargs(api.space_info, repo_id, token=token, files_metadata=True)
        else:
            info = call_with_supported_kwargs(api.model_info, repo_id, token=token, files_metadata=True)
        card = getattr(info, "cardData", None) or {}
        license_name = card.get("license") or getattr(info, "license", None) or "N/A"
        gated = getattr(info, "gated", None)
        private = getattr(info, "private", None)
        pipeline = getattr(info, "pipeline_tag", None) or "N/A"
        sdk = getattr(info, "sdk", None) or "N/A"
        files = extract_file_entries(info)
        if not files:
            try:
                names = api.list_repo_files(repo_id=repo_id, repo_type=repo_type, token=token)
                files = [{"path": name, "size": None} for name in (names or [])]
            except Exception:
                files = []
        risk = files_risk_report(files)
        total_size = human_bytes(risk.get("total_size_known")) if risk.get("total_size_known") else "N/A"
        # Recent huggingface_hub releases expose last_modified; older ones used lastModified.
        last_modified = getattr(info, "last_modified", None) or getattr(info, "lastModified", None) or "N/A"
        preview: dict[str, Any] = {
            "Repo ID": getattr(info, "id", repo_id),
            "Type": repo_type,
            "Author": getattr(info, "author", None) or getattr(info, "owner", None) or "N/A",
            "Likes": getattr(info, "likes", 0) or 0,
            "Downloads": getattr(info, "downloads", 0) or 0,
            "Last Modified": safe_str(last_modified, 200),
            "License": str(license_name) if license_name else "N/A",
            "Pipeline": str(pipeline) if pipeline else "N/A",
            "Gated": "Yes" if gated is True else ("No" if gated is False else "N/A"),
            "Private": "Yes" if private is True else ("No" if private is False else "N/A"),
            "Total Size": total_size,
            "Files Count": risk.get("files_count", 0),
        }
        if repo_type == "space":
            preview["SDK"] = sdk or "N/A"
            hardware = getattr(info, "hardware", None)
            if hardware:
                preview["Hardware"] = safe_str(hardware, 200)
        preview.update(
            {
                "_pipeline_tag": pipeline or "N/A",
                "_sdk": sdk or "N/A",
                "_files": files,
                "_risk": risk,
                "_has_gguf": bool(risk.get("has_gguf")),
                "_rid": repo_id,
                "_rt": repo_type,
            }
        )
        return True, preview, None
    except HfHubHTTPError as error:
        return False, None, safe_hf_error(error)
    except Exception as error:
        return False, None, f"Unexpected Error: {safe_str(error, 500)}"


_PUBLIC_CACHE: dict[tuple[str, str], tuple[bool, dict[str, Any] | None, str | None]] = {}


def cached_public(repo_type: str, repo_id: str) -> tuple[bool, dict[str, Any] | None, str | None]:
    """Fetch public repo metadata, caching successful responses only.

    Failures are never cached, so a transient network error does not become
    sticky until the process restarts.
    """
    key = (norm_type(repo_type), norm_id(repo_id))
    if key in _PUBLIC_CACHE:
        return _PUBLIC_CACHE[key]
    result = fetch_repo_info(key[0], key[1], token=None)
    if result[0]:
        _PUBLIC_CACHE[key] = result
    return result


def build_export_files(state: dict[str, Any]) -> dict[str, str]:
    if not isinstance(state, dict) or not state.get("Repo ID"):
        raise ValueError("No repo loaded yet.")
    repo_type = norm_type(state.get("Type", "model"))
    repo_id = norm_id(state.get("Repo ID", "")) or norm_id(state.get("_rid", ""))
    if not is_valid_repo_id(repo_id):
        raise ValueError("Invalid Repo ID. Expected: repo-name or owner/name")
    install = generate_install(repo_type, state)
    quickstart = generate_quickstart(repo_type, repo_id, state)
    snapshot = generate_snapshot_download(repo_type, repo_id)
    requirements = compute_requirements(repo_type, state)
    readme = textwrap.dedent(
        f"""
        # QuickStart — {repo_id}

        Minimal first-run scaffold generated for `{repo_id}`.

        ## Setup

        ```bash
        python -m venv .venv
        source .venv/bin/activate  # Windows: .venv\\Scripts\\activate
        python -m pip install -r requirements.txt
        ```

        ## Run

        ```bash
        python run.py
        ```

        ## Download full snapshot

        ```bash
        python download.py
        ```

        ## Reference install

        ```bash
        {install}
        ```
        """
    ).strip()
    run_py = "\n".join(
        [
            "def main():",
            '    print("Install/reference command:")',
            f"    print({install!r})",
            "",
            textwrap.indent(quickstart, "    "),
            "",
            "",
            'if __name__ == "__main__":',
            "    main()",
        ]
    )
    download_py = snapshot.strip()
    return {
        "README.md": readme + "\n",
        "requirements.txt": "\n".join(requirements) + "\n",
        ".env.example": "HF_TOKEN=\n",
        "run.py": run_py + "\n",
        "download.py": download_py + "\n",
    }


def build_quickstart_zip(state: dict[str, Any]) -> tuple[str | None, str]:
    try:
        files = build_export_files(state)
    except ValueError as error:
        return None, str(error)
    repo_id = norm_id(state.get("Repo ID", "")) or norm_id(state.get("_rid", "repo"))
    temp_dir = Path(tempfile.mkdtemp(prefix="quickstart_"))
    zip_path = temp_dir / f"{repo_id.replace('/', '__')}_quickstart.zip"
    project_dir = temp_dir / "project"
    project_dir.mkdir(parents=True, exist_ok=True)
    for name, content in files.items():
        path = project_dir / name
        path.write_text(content, encoding="utf-8")
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for name in files:
            archive.write(project_dir / name, arcname=name)
    return str(zip_path), "Zip built. Download it, unzip it, then run: python run.py"
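

# Minimal offline smoke test: a sketch that exercises only the pure helpers
# above (no network calls, no token). The repo ID is a placeholder, not a
# real repo.
if __name__ == "__main__":
    demo_type, demo_id = parse_hf_input("https://huggingface.co/datasets/owner/data/tree/main")
    print(demo_type, demo_id)  # dataset owner/data
    print(is_valid_repo_id(demo_id))  # True
    print(human_bytes(3 * 1024**3))  # 3.00 GB
    print(generate_badge(demo_type, demo_id))
    print(generate_cli_download(demo_type, demo_id))
    print(generate_snapshot_download(demo_type, demo_id))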