from __future__ import annotations

import base64
import hashlib
import io
import json
import os
import tempfile
from dataclasses import dataclass
from typing import Any

import gradio as gr
import numpy as np
from PIL import Image
# ---- Model (SigLIP 768d) ---------------------------------------------------

SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"


@dataclass
class _Embedder:
    processor: Any
    model: Any


_EMBEDDER: _Embedder | None = None
def _get_embedder() -> _Embedder:
    global _EMBEDDER
    if _EMBEDDER is not None:
        return _EMBEDDER
    import torch
    from transformers import AutoModel, AutoProcessor

    processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
    model = AutoModel.from_pretrained(SIGLIP_MODEL_ID)
    model.eval()
    torch.set_grad_enabled(False)
    _EMBEDDER = _Embedder(processor=processor, model=model)
    return _EMBEDDER
def _to_pil(x: Any) -> Image.Image:
    if isinstance(x, Image.Image):
        return x
    if isinstance(x, dict) and isinstance(x.get("path"), str):
        return Image.open(x["path"]).convert("RGBA")
    if isinstance(x, str):
        return Image.open(x).convert("RGBA")
    raise TypeError(f"Unsupported image input: {type(x).__name__}")
def _sha256_bytes(b: bytes) -> str:
    return hashlib.sha256(b).hexdigest()


def _sha256_image(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return _sha256_bytes(buf.getvalue())
def _l2_normalize(v: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(v, axis=-1, keepdims=True)
    n = np.maximum(n, 1e-12)  # guard against division by zero for all-zero vectors
    return v / n
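

# With unit-length vectors, cosine similarity reduces to a dot product. A
# minimal sketch (the `query`/`corpus` arrays are illustrative stand-ins for
# stored embeddings, not part of this service):
#
#   query = _l2_normalize(np.random.rand(768).astype("float32"))
#   corpus = _l2_normalize(np.random.rand(100, 768).astype("float32"))
#   scores = corpus @ query        # (100,) cosine similarities in [-1, 1]
#   best = int(np.argmax(scores))  # index of the nearest neighbor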
def _embed_pils(pils: list[Image.Image]) -> list[dict[str, Any]]:
    import torch

    emb = _get_embedder()
    inputs = emb.processor(images=[p.convert("RGB") for p in pils], return_tensors="pt")
    with torch.no_grad():
        # SigLIP-style models expose get_image_features on the multi-modal wrapper.
        if hasattr(emb.model, "get_image_features"):
            feats = emb.model.get_image_features(**inputs)
        else:
            out = emb.model(**inputs)
            # `or` on a tensor raises "Boolean value of Tensor is ambiguous",
            # so check for None explicitly.
            feats = getattr(out, "pooler_output", None)
            if feats is None:
                feats = out.last_hidden_state[:, 0, :]
    feats = feats.detach().cpu().numpy().astype("float32")
    feats = _l2_normalize(feats)
    results: list[dict[str, Any]] = []
    for p, vec in zip(pils, feats):
        results.append(
            {
                "dims": int(vec.shape[0]),
                "norm": "l2",
                "model_id": SIGLIP_MODEL_ID,
                "sha256": _sha256_image(p),
                "vector": vec.tolist(),
            }
        )
    return results
# ---- Metrics / Heuristics ---------------------------------------------------


def _dhash(img: Image.Image, size: int = 8) -> str:
    g = img.convert("L").resize((size + 1, size), Image.BILINEAR)
    a = np.asarray(g, dtype=np.int16)
    diff = a[:, 1:] > a[:, :-1]
    bits = "".join("1" if x else "0" for x in diff.flatten().tolist())
    return hex(int(bits, 2))[2:].rjust(size * size // 4, "0")
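

# dhash values are compared by Hamming distance over their bits; a small
# distance suggests a near-duplicate image. A minimal helper sketch, not
# wired into the UI:


def _dhash_distance(h1: str, h2: str) -> int:
    """Hamming distance between two equal-length dhash hex strings."""
    return bin(int(h1, 16) ^ int(h2, 16)).count("1")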
def _laplacian_var(img: Image.Image) -> float:
    g = img.convert("L")
    a = np.asarray(g, dtype=np.float32)
    k = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
    # Simple valid-mode 2-D convolution with the Laplacian kernel.
    h, w = a.shape
    if h < 3 or w < 3:
        return 0.0
    out = (
        a[1 : h - 1, 0 : w - 2] * k[1, 0]
        + a[0 : h - 2, 1 : w - 1] * k[0, 1]
        + a[1 : h - 1, 1 : w - 1] * k[1, 1]
        + a[2:h, 1 : w - 1] * k[2, 1]
        + a[1 : h - 1, 2:w] * k[1, 2]
    )
    return float(np.var(out))
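

# Variance of the Laplacian response is a common blur heuristic: crisp edges
# produce high variance, blurry images low variance. Any cutoff is resolution-
# and content-dependent; the value below is only an illustrative starting
# point, not a calibrated threshold:
#
#   if _laplacian_var(img) < 100.0:
#       print("image may be blurry")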
def image_metrics(image: Any) -> str:
    img = _to_pil(image)
    arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    has_alpha = img.mode in ("RGBA", "LA")
    alpha_cov = 1.0
    if has_alpha:
        a = np.asarray(img.split()[-1], dtype=np.float32) / 255.0
        alpha_cov = float(np.mean(a > 0.05))
    metrics = {
        "width": img.width,
        "height": img.height,
        "blur_laplacian_var": _laplacian_var(img),
        "contrast_std": float(np.std(arr)),
        "mean_brightness": float(np.mean(arr)),
        "dhash": _dhash(img),
        "has_alpha": bool(has_alpha),
        "alpha_coverage": alpha_cov,
        "sha256": _sha256_image(img),
    }
    return json.dumps(metrics)
# ---- VLM prep (OpenAI image_url data URL) ----------------------------------


def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
    max_side = int(max_side)
    if max_side <= 0:
        return img
    w, h = img.size
    m = max(w, h)
    if m <= max_side:
        return img
    scale = max_side / float(m)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    return img.resize((nw, nh), Image.LANCZOS)
def prepare_for_openai_vlm(image: Any, max_side: int = 768, fmt: str = "webp", quality: int = 85) -> str:
    img = _to_pil(image)
    img = _resize_max_side(img, max_side=max_side)
    fmt = (fmt or "webp").lower()
    quality = int(quality)
    buf = io.BytesIO()
    if fmt in ("jpeg", "jpg"):
        mime = "image/jpeg"
        img.convert("RGB").save(buf, format="JPEG", quality=quality, optimize=True)
    elif fmt == "png":
        mime = "image/png"
        img.save(buf, format="PNG", optimize=True)
    else:
        mime = "image/webp"
        img.convert("RGB").save(buf, format="WEBP", quality=quality, method=6)
    b = buf.getvalue()
    url = f"data:{mime};base64," + base64.b64encode(b).decode("ascii")
    out = {
        "url": url,
        "mime": mime,
        "width": img.width,
        "height": img.height,
        "sha256": _sha256_bytes(b),
    }
    return json.dumps(out)
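

# Sketch of how the returned data URL slots into an OpenAI-style chat message
# (hedged: the prompt text is a placeholder and no client call is made here):
#
#   prep = json.loads(prepare_for_openai_vlm("sprite.png"))
#   message = {
#       "role": "user",
#       "content": [
#           {"type": "text", "text": "Describe this sprite."},
#           {"type": "image_url", "image_url": {"url": prep["url"]}},
#       ],
#   }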
def prepare_for_openai_vlm_batch(images: list[Any], max_side: int = 768, fmt: str = "webp", quality: int = 85) -> str:
    out = []
    for x in images or []:
        out.append(json.loads(prepare_for_openai_vlm(x, max_side=max_side, fmt=fmt, quality=quality)))
    return json.dumps(out)
# ---- Background removal + alpha trim ----------------------------------------


def bg_remove(image: Any) -> tuple[str, str]:
    from rembg import remove

    img = _to_pil(image).convert("RGBA")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    out_bytes = remove(buf.getvalue())
    # Write to a unique temp file Gradio can serve (a fixed filename would
    # race under concurrent requests).
    fd, out_path = tempfile.mkstemp(suffix=".png", prefix="bg_removed_")
    with os.fdopen(fd, "wb") as f:
        f.write(out_bytes)
    meta = {"method": "rembg", "sha256_in": _sha256_image(img), "sha256_out": _sha256_bytes(out_bytes)}
    return out_path, json.dumps(meta)
def trim_alpha(image: Any) -> tuple[str, str]:
    img = _to_pil(image).convert("RGBA")
    a = np.asarray(img.split()[-1], dtype=np.uint8)
    ys, xs = np.where(a > 0)
    fd, out_path = tempfile.mkstemp(suffix=".png", prefix="trimmed_")
    os.close(fd)
    if len(xs) == 0 or len(ys) == 0:
        # Fully transparent image: nothing to trim, return it unchanged.
        img.save(out_path, format="PNG")
        meta = {"bbox": [0, 0, img.width, img.height], "orig_size": [img.width, img.height]}
        return out_path, json.dumps(meta)
    x0, x1 = int(xs.min()), int(xs.max())
    y0, y1 = int(ys.min()), int(ys.max())
    # Inclusive pixel indices -> width/height.
    w = x1 - x0 + 1
    h = y1 - y0 + 1
    cropped = img.crop((x0, y0, x0 + w, y0 + h))
    cropped.save(out_path, format="PNG")
    meta = {"bbox": [x0, y0, w, h], "orig_size": [img.width, img.height]}
    return out_path, json.dumps(meta)
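

# The bbox is [x, y, w, h] in the original image's coordinates, so a caller
# can re-composite the trimmed sprite at its original position. Hedged sketch
# ("sprite.png" is a hypothetical input):
#
#   path, meta_json = trim_alpha("sprite.png")
#   meta = json.loads(meta_json)
#   x, y, w, h = meta["bbox"]
#   canvas = Image.new("RGBA", tuple(meta["orig_size"]), (0, 0, 0, 0))
#   canvas.alpha_composite(Image.open(path).convert("RGBA"), (x, y))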
# ---- Spritesheet packing ----------------------------------------------------


def pack_spritesheet(images: list[Any], names_json: str) -> tuple[str | None, str]:
    try:
        names = json.loads(names_json or "[]")
    except Exception:
        names = []
    if not isinstance(names, list):
        names = []
    pils = [_to_pil(x).convert("RGBA") for x in (images or [])]
    if not pils:
        # Return None for the file output so Gradio doesn't try to serve "".
        return None, json.dumps({"error": "no_images"})
    # Simple grid packer: fixed column count, uniform cell sized to the largest image.
    cols = min(4, len(pils))
    rows = int(np.ceil(len(pils) / cols))
    cell_w = max(p.width for p in pils)
    cell_h = max(p.height for p in pils)
    sheet = Image.new("RGBA", (cell_w * cols, cell_h * rows), (0, 0, 0, 0))
    mapping: dict[str, Any] = {"cell": [cell_w, cell_h], "items": {}}
    for i, p in enumerate(pils):
        r = i // cols
        c = i % cols
        x = c * cell_w
        y = r * cell_h
        sheet.alpha_composite(p, (x, y))
        key = str(names[i]) if i < len(names) else f"item_{i}"
        mapping["items"][key] = {"x": x, "y": y, "w": p.width, "h": p.height}
    fd, out_path = tempfile.mkstemp(suffix=".png", prefix="spritesheet_")
    os.close(fd)
    sheet.save(out_path, format="PNG")
    return out_path, json.dumps(mapping)
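

# Reading a sprite back out of the sheet using the returned map. A minimal
# sketch; "happy" assumes that name was supplied in names_json:
#
#   sheet_path, map_json = pack_spritesheet(["a.png", "b.png"], '["neutral","happy"]')
#   m = json.loads(map_json)["items"]["happy"]
#   sheet = Image.open(sheet_path)
#   sprite = sheet.crop((m["x"], m["y"], m["x"] + m["w"], m["y"] + m["h"]))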
# ---- Public endpoints -------------------------------------------------------


def health() -> str:
    return json.dumps({"ok": True, "embed_dims": 768, "model_id": SIGLIP_MODEL_ID})


def embed_images_batch(images: list[Any]) -> str:
    pils = [_to_pil(x) for x in (images or [])]
    out = _embed_pils(pils)
    return json.dumps(out)
with gr.Blocks() as demo:
    gr.Markdown("# Image Processing Service")
    with gr.Tab("API"):
        inp = gr.File(label="Image", file_types=["image"])
        max_side = gr.Slider(128, 2048, value=768, step=64, label="max_side (VLM prep)")
        fmt = gr.Dropdown(["webp", "jpeg", "png"], value="webp", label="format")
        quality = gr.Slider(10, 100, value=85, step=1, label="quality")
        out_json = gr.Code(language="json", label="Output JSON")
        out_file = gr.File(label="Output File")
        gr.Button("Health").click(health, outputs=out_json, api_name="/health")
        gr.Button("Prepare for OpenAI VLM").click(
            prepare_for_openai_vlm,
            inputs=[inp, max_side, fmt, quality],
            outputs=out_json,
            api_name="/prepare_for_openai_vlm",
        )
        gr.Button("Metrics").click(image_metrics, inputs=inp, outputs=out_json, api_name="/image_metrics")
        gr.Button("BG Remove").click(bg_remove, inputs=inp, outputs=[out_file, out_json], api_name="/bg_remove")
        gr.Button("Trim Alpha").click(trim_alpha, inputs=inp, outputs=[out_file, out_json], api_name="/trim_alpha")

        # Batch endpoints (API-only; the UI is intentionally minimal).
        batch_inp = gr.Files(label="Images (batch)", file_types=["image"])
        batch_out = gr.Code(language="json", label="Batch JSON")
        gr.Button("Prepare VLM Batch").click(
            prepare_for_openai_vlm_batch,
            inputs=[batch_inp, max_side, fmt, quality],
            outputs=batch_out,
            api_name="/prepare_for_openai_vlm_batch",
        )
        gr.Button("Embed Batch").click(embed_images_batch, inputs=batch_inp, outputs=batch_out, api_name="/embed_images_batch")

        # Spritesheet packing
        names = gr.Textbox(label="Names JSON", value='["neutral","happy"]')
        sheet_file = gr.File(label="Spritesheet PNG")
        sheet_map = gr.Code(language="json", label="Spritesheet Map")
        gr.Button("Pack Spritesheet").click(
            pack_spritesheet, inputs=[batch_inp, names], outputs=[sheet_file, sheet_map], api_name="/pack_spritesheet"
        )
if __name__ == "__main__":
    # HF Spaces runs behind a proxy; bind to 0.0.0.0 and the platform port.
    port = int(os.environ.get("PORT", "7860"))
    demo.queue(default_concurrency_limit=2, max_size=64).launch(server_name="0.0.0.0", server_port=port)