| | """ |
| | handler.py — Hugging Face Inference Endpoint custom handler |
| | Outputs: GIF, WebM, ZIP(frames) |
| | |
| | This version maintains UNIVERSAL compatibility: |
| | - Defensive argument guessing (num_frames vs video_length) |
| | - Robust output shape parsing (TBL, BCTHW, etc.) |
| | - Adds Support for Image-to-Video via `image` input (base64) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | |
| | |
| | |
| |
|
| | import os |
| | import sys |
| | import types |
| |
|
# Default these flags to "1" for imageio and the Hugging Face Hub client, without
# overriding values the deployment environment already provides.
for _flag_name in ("IMAGEIO_NO_INTERNET", "HF_HUB_DISABLE_TELEMETRY"):
    os.environ.setdefault(_flag_name, "1")
del _flag_name
| |
|
def _patch_hf_toolkit_ffmpeg_plugin() -> dict:
    """
    Best-effort patching so huggingface_inference_toolkit tolerates lookups of a
    plugin named "ffmpeg".

    Never raises. Three independent measures are attempted:
      1. add an "ffmpeg" entry to every dict-valued toolkit attribute whose name
         mentions plugins/registries;
      2. wrap toolkit callables named like get_/resolve_/load_plugin so a first
         positional argument of "ffmpeg" (any case) short-circuits;
      3. register a placeholder module under
         "huggingface_inference_toolkit.plugins.ffmpeg" in sys.modules.

    Returns:
        {"patched": bool, "details": [str, ...]} describing what was done.
    """
    report = {"patched": False, "details": []}
    log = report["details"].append

    try:
        import huggingface_inference_toolkit as toolkit
    except Exception as exc:
        log(f"huggingface_inference_toolkit not importable: {exc}")
        return report
    log("imported huggingface_inference_toolkit")

    # 1) Dict-valued attributes that look like plugin registries.
    registries = []
    for attr in dir(toolkit):
        lowered = attr.lower()
        if not any(hint in lowered for hint in ("plugin", "registry", "registries", "plugins")):
            continue
        try:
            candidate = getattr(toolkit, attr)
            if isinstance(candidate, dict):
                registries.append((attr, candidate))
        except Exception:
            continue

    for attr, registry in registries:
        if "ffmpeg" in registry:
            continue
        try:
            registry["ffmpeg"] = object()
        except Exception as exc:
            log(f"failed adding to hfit.{attr}: {exc}")
        else:
            report["patched"] = True
            log(f"added ffmpeg to registry dict: hfit.{attr}")

    # 2) Resolver-style callables: intercept the "ffmpeg" name, delegate the rest.
    def _make_wrapper(original):
        def wrapped(*args, **kwargs):
            if args and isinstance(args[0], str) and args[0].lower() == "ffmpeg":
                return object()
            return original(*args, **kwargs)

        # Marker so a second patch run does not wrap the wrapper again.
        wrapped.__ffmpeg_patched__ = True
        return wrapped

    resolver_names = [
        attr
        for attr in dir(toolkit)
        if any(hint in attr.lower() for hint in ("get_plugin", "resolve_plugin", "load_plugin"))
    ]
    for attr in resolver_names:
        try:
            resolver = getattr(toolkit, attr)
            if not callable(resolver) or getattr(resolver, "__ffmpeg_patched__", False):
                continue
            setattr(toolkit, attr, _make_wrapper(resolver))
        except Exception as exc:
            log(f"failed wrapping {attr}: {exc}")
            continue
        report["patched"] = True
        log(f"wrapped callable: hfit.{attr} to accept ffmpeg")

    # 3) Placeholder module so importing the plugin path does not fail.
    placeholder_name = "huggingface_inference_toolkit.plugins.ffmpeg"
    if placeholder_name not in sys.modules:
        sys.modules[placeholder_name] = types.ModuleType(
            placeholder_name,
            "Dummy ffmpeg plugin injected by handler.py to avoid registry errors.",
        )
        log(f"injected sys.modules['{placeholder_name}'] (dummy module)")

    return report


# Run once at import time; the diagnostics are echoed in every response.
HF_TOOLKIT_PATCH_DIAG = _patch_hf_toolkit_ffmpeg_plugin()
| |
|
| | |
| | |
| | |
| |
|
| | import base64 |
| | import io |
| | import time |
| | import tempfile |
| | import zipfile |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | import numpy as np |
| | from PIL import Image |
| |
|
| | import imageio |
| |
|
# Locate the ffmpeg binary shipped with imageio-ffmpeg and export it as
# IMAGEIO_FFMPEG_EXE. _FFMPEG_EXE stays "" when the package is missing or the
# lookup fails for any reason.
_FFMPEG_EXE = ""
try:
    import imageio_ffmpeg

    _FFMPEG_EXE = imageio_ffmpeg.get_ffmpeg_exe()
    os.environ["IMAGEIO_FFMPEG_EXE"] = _FFMPEG_EXE
except Exception:
    _FFMPEG_EXE = ""
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _now_ms() -> int:
    """Return the current wall-clock time in whole milliseconds since the Unix epoch."""
    seconds = time.time()
    return int(seconds * 1000)
| |
|
| |
|
def _b64(data: bytes) -> str:
    """Encode *data* with standard base64 and return the result as text."""
    encoded = base64.b64encode(data)
    return str(encoded, "utf-8")
| |
|
| |
|
def _b64_to_pil(b64_str: str) -> Image.Image:
    """
    Decode a base64-encoded image into an RGB PIL image.

    Accepts a bare base64 payload or a data URI such as
    "data:image/png;base64,<payload>". Whitespace inside the payload is ignored,
    and missing trailing "=" padding is tolerated (some encoders omit it).

    Raises:
        binascii.Error: for malformed base64.
        PIL.UnidentifiedImageError / OSError: if the bytes are not a readable image.
    """
    if "," in b64_str:
        # Data-URI prefix: keep everything after the FIRST comma only.
        b64_str = b64_str.split(",", 1)[1]
    payload = "".join(b64_str.split())
    # Restore padding to a multiple of 4; surplus "=" is harmless for b64decode.
    payload += "=" * (-len(payload) % 4)
    data = base64.b64decode(payload)
    # convert() returns an independent, fully loaded copy, so the source handle
    # can be closed right away.
    with Image.open(io.BytesIO(data)) as img:
        return img.convert("RGB")
| |
|
| |
|
def _clamp_uint8_frame(frame: np.ndarray) -> np.ndarray:
    """
    Normalize one frame into a C-contiguous uint8 RGB array of shape (H, W, 3).

    Accepted layouts: (H, W) grayscale, (H, W, C) or channel-first (C, H, W) with
    C in {1, 3, 4}, each optionally with a leading axis of size 1. Alpha is
    dropped and single-channel data is replicated to RGB.

    Accepted value ranges:
      * uint8: kept as-is (apart from layout fixes);
      * other integer dtypes: assumed to be on the 0..255 scale and clipped;
      * float/bool: [0, 1], or [-1, 1] when any value is negative. NaN is treated
        as 0 and +/-inf saturate.

    Raises:
        ValueError: if the array cannot be interpreted as an image frame.
    """
    frame = np.asarray(frame)

    # Drop a leading singleton batch axis, e.g. (1, H, W, C).
    if frame.ndim == 4 and frame.shape[0] == 1:
        frame = frame[0]

    # Grayscale (H, W) -> (H, W, 1); replicated to RGB below.
    if frame.ndim == 2:
        frame = frame[..., np.newaxis]

    if frame.ndim != 3:
        raise ValueError(f"Frame must be 2D or 3D array; got shape {frame.shape}")

    channel_counts = (1, 3, 4)
    # Channel-first (C, H, W): only reinterpret when the trailing axis cannot be a
    # channel axis, so (H, W, C) inputs keep their meaning.
    if frame.shape[-1] not in channel_counts and frame.shape[0] in channel_counts:
        frame = np.transpose(frame, (1, 2, 0))

    channels = frame.shape[-1]
    if channels == 4:
        frame = frame[..., :3]
    elif channels == 1:
        frame = np.repeat(frame, 3, axis=-1)
    elif channels != 3:
        raise ValueError(f"Unsupported channels shape: {frame.shape}")

    if frame.dtype == np.uint8:
        # Transposed views are non-contiguous; encoders want packed rows.
        return np.ascontiguousarray(frame)

    if np.issubdtype(frame.dtype, np.integer):
        # Integer pixels are already 0..255-scaled; the float path below would
        # clip everything >= 1 to white.
        return np.clip(frame, 0, 255).astype(np.uint8)

    f = frame.astype(np.float32)
    # Sanitize non-finite values first: a single NaN would make f.min() NaN and
    # silently disable the [-1, 1] range detection for the whole frame.
    f = np.nan_to_num(f, nan=0.0, posinf=1.0, neginf=-1.0)
    if f.size and f.min() < 0.0:
        f = (f + 1.0) / 2.0
    f = np.clip(f, 0.0, 1.0)
    return (f * 255.0).round().astype(np.uint8)
| |
|
| |
|
def _encode_gif(frames: List[np.ndarray], fps: int) -> bytes:
    """
    Encode *frames* as an endlessly looping animated GIF and return the bytes.

    Each frame is normalized with _clamp_uint8_frame. The per-frame delay is
    int(1000 / fps) milliseconds, with fps floored at 1.

    Raises:
        ValueError: if *frames* is empty.
    """
    if not frames:
        raise ValueError("No frames to encode GIF.")
    first, *rest = (Image.fromarray(_clamp_uint8_frame(fr)) for fr in frames)
    out = io.BytesIO()
    first.save(
        out,
        format="GIF",
        save_all=True,
        append_images=rest,
        duration=int(1000 / max(1, fps)),
        loop=0,
        optimize=False,
        disposal=2,
    )
    return out.getvalue()
| |
|
| |
|
def _encode_webm(frames: List[np.ndarray], fps: int, quality: str = "good") -> bytes:
    """
    Encode *frames* as a VP9 WebM clip via imageio's FFMPEG writer and return the bytes.

    Args:
        frames: frames accepted by _clamp_uint8_frame.
        fps: output frame rate (floored at 1).
        quality: "fast" or "best"; any other value (or None) selects the default
            "good" tier. Each tier fixes a constant-quality CRF plus libvpx speed
            settings.

    Raises:
        ValueError: if *frames* is empty.
    """
    if not frames:
        raise ValueError("No frames to encode WebM.")

    # tier -> (crf, libvpx deadline, libvpx cpu-used). libvpx-vp9 has no x264-style
    # "-preset" option (it was silently ignored), so speed/quality is controlled
    # with -deadline / -cpu-used. Higher cpu-used is faster and lower quality.
    tiers = {
        "fast": (42, "good", 5),
        "best": (28, "good", 1),
    }
    crf, deadline, cpu_used = tiers.get((quality or "good").lower(), (34, "good", 2))

    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
        out_path = tmp.name

    try:
        writer = imageio.get_writer(
            out_path,
            fps=max(1, fps),
            format="FFMPEG",
            codec="libvpx-vp9",
            # quality=None stops imageio from adding its own variable-bitrate flag,
            # so the CRF settings below are the only rate control in effect.
            quality=None,
            ffmpeg_params=[
                "-pix_fmt", "yuv420p",
                "-crf", str(crf),
                "-b:v", "0",
                "-deadline", deadline,
                "-cpu-used", str(cpu_used),
            ],
        )
        try:
            for frame in frames:
                writer.append_data(_clamp_uint8_frame(frame))
        finally:
            writer.close()

        with open(out_path, "rb") as fh:
            return fh.read()
    finally:
        # Best-effort cleanup of the temporary output file.
        try:
            os.remove(out_path)
        except OSError:
            pass
| |
|
| |
|
def _encode_zip_frames(frames: List[np.ndarray]) -> bytes:
    """
    Pack *frames* into an in-memory ZIP of PNG images and return the archive bytes.

    Members are named frame_000000.png, frame_000001.png, ... in input order;
    each frame is normalized with _clamp_uint8_frame first.

    Raises:
        ValueError: if *frames* is empty.
    """
    if not frames:
        raise ValueError("No frames to ZIP.")

    def _png_bytes(frame: np.ndarray) -> bytes:
        png = io.BytesIO()
        Image.fromarray(_clamp_uint8_frame(frame)).save(png, format="PNG", optimize=True)
        return png.getvalue()

    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=6) as zf:
        for index, frame in enumerate(frames):
            zf.writestr(f"frame_{index:06d}.png", _png_bytes(frame))
    return archive.getvalue()
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class GenParams:
    """Normalized generation settings for a single request, built by _parse_request."""

    prompt: str  # text prompt; may be an empty string
    negative_prompt: str  # "" when not supplied; sent to the pipeline as None in that case
    num_frames: int  # requested frame count, forwarded under the pipeline's frame-count kwarg
    fps: int  # default playback rate for the encoders (per-format config may override)
    height: int  # requested output height, passed straight to the pipeline
    width: int  # requested output width, passed straight to the pipeline
    seed: Optional[int]  # None -> no explicit torch.Generator is created
    num_inference_steps: int
    guidance_scale: float
    image_b64: Optional[str] = None  # base64 / data-URI conditioning image; when set -> image-to-video
| |
|
| |
|
def _unwrap_inputs(payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return payload["inputs"] when it is itself a dict (the {"inputs": {...}}
    envelope); otherwise return *payload* unchanged.
    """
    nested = payload.get("inputs") if isinstance(payload, dict) else None
    return nested if isinstance(nested, dict) else payload
| |
|
| |
|
def _parse_request(payload: Dict[str, Any]) -> Tuple[GenParams, List[str], bool, Dict[str, Any]]:
    """
    Turn a raw request payload into generation parameters and output options.

    The payload may be flat or wrapped as {"inputs": {...}}. A plain string under
    "inputs" is used as the prompt when "prompt" is absent.

    Returns:
        (params, outputs, return_base64, output_config) where
          * params: GenParams with num_frames/fps/num_inference_steps floored at 1
            and height/width floored at 64;
          * outputs: requested formats filtered to {"gif", "webm", "zip"},
            defaulting to ["gif"];
          * return_base64: True unless the client disabled it (string values such
            as "false"/"0"/"no"/"off" count as disabled);
          * output_config: per-format option dicts from "output_config" and/or
            top-level "gif"/"webm"/"zip" dict keys. This is always a fresh dict, so
            the caller's payload is never mutated.

    Raises:
        ValueError / TypeError: if a numeric field cannot be converted.
    """
    data = _unwrap_inputs(payload)

    prompt = str(data.get("prompt") or data.get("inputs") or "").strip()
    negative_prompt = str(data.get("negative_prompt") or "").strip()

    num_frames = int(data.get("num_frames") or data.get("frames") or 32)
    fps = int(data.get("fps") or 12)
    height = int(data.get("height") or 512)
    width = int(data.get("width") or 512)

    seed = data.get("seed")
    seed = int(seed) if seed is not None and str(seed).strip() != "" else None

    image_b64 = data.get("image") or data.get("image_base64")

    num_inference_steps = int(data.get("num_inference_steps") or 30)

    # 0.0 is a legitimate guidance scale, so only a missing/blank value falls back
    # to the default (a plain `or` would silently turn 0 into 7.5).
    raw_guidance = data.get("guidance_scale")
    if raw_guidance is None or str(raw_guidance).strip() == "":
        guidance_scale = 7.5
    else:
        guidance_scale = float(raw_guidance)

    outputs = data.get("outputs") or ["gif"]
    if isinstance(outputs, str):
        outputs = [outputs]
    allowed = {"gif", "webm", "zip"}
    outputs = [o for o in (str(x).lower() for x in outputs) if o in allowed] or ["gif"]

    raw_flag = data.get("return_base64", True)
    if isinstance(raw_flag, str):
        # bool("false") is True; interpret common string spellings explicitly.
        return_base64 = raw_flag.strip().lower() not in {"", "0", "false", "no", "off"}
    else:
        return_base64 = bool(raw_flag)

    # Copy so the per-format overrides below never mutate the caller's payload;
    # a non-dict "output_config" is ignored rather than crashing later.
    raw_cfg = data.get("output_config")
    out_cfg: Dict[str, Any] = dict(raw_cfg) if isinstance(raw_cfg, dict) else {}
    for k in ("gif", "webm", "zip"):
        if isinstance(data.get(k), dict):
            out_cfg[k] = data[k]

    params = GenParams(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=max(1, num_frames),
        fps=max(1, fps),
        height=max(64, height),
        width=max(64, width),
        seed=seed,
        num_inference_steps=max(1, num_inference_steps),
        guidance_scale=guidance_scale,
        image_b64=image_b64,
    )
    return params, outputs, return_base64, out_cfg
| |
|
| |
|
| | |
| | |
| | |
| |
|
class EndpointHandler:
    """
    Inference Endpoint entry point for text-to-video / image-to-video generation.

    __init__ loads a diffusers pipeline from the repository path once. Load
    failures do not raise: they are stored in ``init_error`` and reported on every
    request. __call__ serves one request and returns a JSON-style dict with "ok",
    the encoded outputs, and diagnostics.
    """

    # Keyword names under which pipelines accept the frame count, in trial order.
    _FRAME_ARG_NAMES = ("num_frames", "video_length", "num_video_frames")
    # Axis sizes treated as "this is a channel axis" when guessing layouts.
    _CHANNEL_COUNTS = (1, 3, 4)

    def __init__(self, path: str = "") -> None:
        self.repo_path = path or ""
        self.pipe = None
        # Device the pipeline actually lives on; only updated after a successful .to().
        self.device = "cpu"
        self.init_error: Optional[str] = None

        print("=== CUSTOM handler.py LOADED (Universal Mode) ===", flush=True)
        print(f"=== HF toolkit patch diag: {HF_TOOLKIT_PATCH_DIAG} ===", flush=True)
        print(f"=== imageio-ffmpeg exe: {_FFMPEG_EXE} ===", flush=True)

        try:
            import torch
            from diffusers import DiffusionPipeline

            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.float16 if device == "cuda" else torch.float32

            subdir = os.getenv("HF_MODEL_SUBDIR", "").strip()
            model_path = self.repo_path if not subdir else os.path.join(self.repo_path, subdir)

            # LTXConditionPipeline only exists in newer diffusers releases. Importing
            # it here (not alongside DiffusionPipeline) lets older installs still
            # reach the generic fallback instead of failing initialization outright.
            try:
                from diffusers import LTXConditionPipeline

                print("Attempting to load LTXConditionPipeline...", flush=True)
                self.pipe = LTXConditionPipeline.from_pretrained(model_path, torch_dtype=dtype)
            except Exception as e:
                print(f"LTXConditionPipeline load failed ({e}), falling back to generic DiffusionPipeline...", flush=True)
                self.pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)

            try:
                self.pipe.to(device)
                self.device = device
            except Exception as e:
                # Keep serving, but make the failed device move visible in the logs.
                print(f"=== pipe.to({device}) failed: {e} ===", flush=True)

            # Memory savers are optional: a missing or failing one must not discard
            # an otherwise loaded pipeline.
            self._try_enable(self.pipe, "enable_vae_slicing")
            self._try_enable(getattr(self.pipe, "vae", None), "enable_tiling")

        except Exception as e:
            self.init_error = str(e)
            self.pipe = None
            print(f"=== PIPELINE INIT FAILED: {self.init_error} ===", flush=True)

    @staticmethod
    def _try_enable(owner: Any, method_name: str) -> None:
        """Call ``owner.<method_name>()`` if it exists; log and ignore any failure."""
        fn = getattr(owner, method_name, None)
        if not callable(fn):
            return
        try:
            fn()
        except Exception as e:
            print(f"=== {method_name} unavailable: {e} ===", flush=True)

    def __call__(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Serve one request.

        On success returns {"ok": True, "outputs": {...}, "diagnostics": {...}},
        where outputs maps "<fmt>_base64" (or "<fmt>_bytes" when return_base64 is
        false) for each requested format. Any exception is caught, its traceback
        printed, and {"ok": False, "error": str(exc), "diagnostics": {...}} is
        returned instead.
        """
        t0 = _now_ms()

        try:
            params, outputs, return_b64, out_cfg = _parse_request(payload)

            frames, gen_diag = self._generate_frames(params)

            t1 = _now_ms()
            result_outputs: Dict[str, Any] = {}

            def _store(fmt: str, blob: bytes) -> None:
                # Same key scheme for every format: "<fmt>_base64" or "<fmt>_bytes".
                if return_b64:
                    result_outputs[f"{fmt}_base64"] = _b64(blob)
                else:
                    result_outputs[f"{fmt}_bytes"] = blob

            if "gif" in outputs:
                gif_fps = int((out_cfg.get("gif") or {}).get("fps") or params.fps)
                _store("gif", _encode_gif(frames, fps=gif_fps))

            t2 = _now_ms()

            if "webm" in outputs:
                webm_cfg = out_cfg.get("webm") or {}
                webm_fps = int(webm_cfg.get("fps") or params.fps)
                webm_quality = str(webm_cfg.get("quality") or "good")
                _store("webm", _encode_webm(frames, fps=webm_fps, quality=webm_quality))

            t3 = _now_ms()

            if "zip" in outputs:
                _store("zip", _encode_zip_frames(frames))

            t4 = _now_ms()

            return {
                "ok": True,
                "outputs": result_outputs,
                "diagnostics": {
                    "timing_ms": {
                        "total": t4 - t0,
                        "generate": t1 - t0,
                        "gif": (t2 - t1) if "gif" in outputs else 0,
                        "webm": (t3 - t2) if "webm" in outputs else 0,
                        "zip": (t4 - t3) if "zip" in outputs else 0,
                    },
                    "generator": gen_diag,
                    "ffmpeg_exe": _FFMPEG_EXE,
                    "hf_toolkit_patch": HF_TOOLKIT_PATCH_DIAG,
                    "init_error": self.init_error,
                },
            }

        except Exception as e:
            import traceback

            traceback.print_exc()
            return {
                "ok": False,
                "error": str(e),
                "diagnostics": {
                    "ffmpeg_exe": _FFMPEG_EXE,
                    "hf_toolkit_patch": HF_TOOLKIT_PATCH_DIAG,
                    "init_error": self.init_error,
                },
            }

    def _generate_frames(self, params: GenParams) -> Tuple[List[np.ndarray], Dict[str, Any]]:
        """
        Run the pipeline for *params* and return (uint8 RGB frames, diagnostics).

        Raises:
            RuntimeError: if the pipeline is not initialized, the pipeline call
                fails, or no frames can be extracted from its output.
        """
        if self.pipe is None:
            raise RuntimeError(f"Model pipeline not initialized. Init error: {self.init_error}")

        generator = None
        if params.seed is not None:
            try:
                import torch

                # Seed on the device the pipeline actually lives on.
                device = getattr(self, "device", None) or ("cuda" if torch.cuda.is_available() else "cpu")
                generator = torch.Generator(device=device).manual_seed(params.seed)
            except Exception:
                generator = None

        kwargs: Dict[str, Any] = {
            "prompt": params.prompt,
            "negative_prompt": params.negative_prompt or None,
            "height": params.height,
            "width": params.width,
            "num_inference_steps": params.num_inference_steps,
            "guidance_scale": params.guidance_scale,
            "generator": generator,
        }

        if params.image_b64:
            print("Received image input, performing Image-to-Video.", flush=True)
            kwargs["image"] = _b64_to_pil(params.image_b64)

        # Drop unset optional arguments so the pipeline applies its own defaults.
        kwargs = {k: v for k, v in kwargs.items() if v is not None}

        output = self._call_pipeline(kwargs, params.num_frames)

        frames = self._extract_frames(output)
        if not frames:
            raise RuntimeError("Could not extract frames from pipeline output (no frames/videos/images found).")

        frames_u8 = [_clamp_uint8_frame(f) for f in frames]

        diag = {
            "prompt_len": len(params.prompt),
            "negative_prompt_len": len(params.negative_prompt),
            "num_frames": len(frames_u8),
            "height": int(frames_u8[0].shape[0]),
            "width": int(frames_u8[0].shape[1]),
            "mode": "i2v" if params.image_b64 else "t2v",
        }
        return frames_u8, diag

    def _call_pipeline(self, kwargs: Dict[str, Any], num_frames: int) -> Any:
        """
        Invoke ``self.pipe`` with the frame count under a keyword it accepts.

        If the pipeline's signature names one of _FRAME_ARG_NAMES, that name is
        used directly. Otherwise (unknown signature, or ``**kwargs``) the names are
        tried in order, moving on only when the call is rejected with TypeError,
        which is what an unexpected keyword raises. Any other exception is a real
        generation failure (e.g. out of memory). It is reported immediately
        instead of re-running the expensive pipeline with another keyword.

        Raises:
            RuntimeError: wrapping the failure.
        """
        candidates = list(self._FRAME_ARG_NAMES)
        accepted = self._accepted_kwargs(self.pipe)
        if accepted is not None:
            known = [name for name in candidates if name in accepted]
            if known:
                candidates = known[:1]

        last_err: Optional[Exception] = None
        for frame_arg in candidates:
            call_kwargs = dict(kwargs)
            call_kwargs[frame_arg] = num_frames
            try:
                return self.pipe(**call_kwargs)
            except TypeError as e:
                last_err = e
            except Exception as e:
                raise RuntimeError(f"Pipeline call failed. Last error: {e}") from e
        raise RuntimeError(f"Pipeline call failed. Last error: {last_err}")

    @staticmethod
    def _accepted_kwargs(fn: Any) -> Optional[set]:
        """Return the keyword-passable parameter names of *fn*, or None if unknown or it takes **kwargs."""
        import inspect

        try:
            sig = inspect.signature(fn)
        except (TypeError, ValueError):
            return None
        names = set()
        for param in sig.parameters.values():
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                return None
            if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
                names.add(param.name)
        return names

    @staticmethod
    def _to_numpy(value: Any) -> np.ndarray:
        """Convert a torch tensor (any device/dtype) or an array-like into a numpy array."""
        # A tensor can only exist if torch is already imported, so avoid importing it here.
        torch = sys.modules.get("torch")
        if torch is not None and isinstance(value, torch.Tensor):
            tensor = value.detach().cpu()
            if tensor.is_floating_point():
                # numpy has no bfloat16; float32 is lossless for fp16/bf16 frames.
                tensor = tensor.float()
            return tensor.numpy()
        return np.array(value)

    @classmethod
    def _split_video_array(cls, arr: np.ndarray) -> List[np.ndarray]:
        """
        Split a 4D/5D video array into a list of per-frame arrays.

        A 5D input is treated as batched and only element 0 is kept. The 4D result
        is reordered to (T, H, W, C) when its layout is recognisably (T, C, H, W)
        or (C, T, H, W); otherwise it is assumed to be (T, H, W, C) already.
        """
        if arr.ndim == 5:
            arr = arr[0]
        ch = cls._CHANNEL_COUNTS
        if arr.shape[-1] not in ch:
            if arr.shape[1] in ch:  # (T, C, H, W)
                arr = np.transpose(arr, (0, 2, 3, 1))
            elif arr.shape[0] in ch:  # (C, T, H, W), e.g. BCTHW after batch removal
                arr = np.transpose(arr, (1, 2, 3, 0))
        return [arr[t] for t in range(arr.shape[0])]

    @classmethod
    def _frames_from_value(cls, key: str, value: Any) -> List[np.ndarray]:
        """Convert one output field ("frames", "videos" or "images") into per-frame arrays."""
        if key == "images":
            if isinstance(value, list):
                return [np.array(im) for im in value]
            arr = cls._to_numpy(value)
            # A batched image array (B, ...) yields one frame per image.
            return [arr[i] for i in range(arr.shape[0])] if arr.ndim == 4 else [arr]

        # A batch of one given as a nested list ([[f0, f1, ...]]): use the inner list.
        if isinstance(value, list) and len(value) == 1 and isinstance(value[0], (list, tuple)):
            value = value[0]

        try:
            arr = cls._to_numpy(value)
        except (ValueError, TypeError, RuntimeError):
            # e.g. ragged nested lists, which numpy refuses to stack.
            arr = None
        if arr is not None and arr.ndim in (4, 5):
            return cls._split_video_array(arr)
        if isinstance(value, (list, tuple)):
            return [cls._to_numpy(item) for item in value]
        label = "frames" if key == "frames" else "video tensor"
        raise ValueError(f"Unexpected {label} shape: {None if arr is None else arr.shape}")

    @classmethod
    def _extract_frames(cls, output: Any) -> List[np.ndarray]:
        """
        Pull a list of per-frame arrays out of a pipeline result.

        Checks output.frames, output.videos, output.images (first non-None wins),
        then the same keys of a plain dict result, then element 0 of a tuple
        result (return_dict=False style). Returns [] if nothing usable is found.
        """
        keys = ("frames", "videos", "images")
        for key in keys:
            value = getattr(output, key, None)
            if value is not None:
                return cls._frames_from_value(key, value)

        if isinstance(output, dict):
            for key in keys:
                value = output.get(key)
                if value is not None:
                    return cls._frames_from_value(key, value)

        if isinstance(output, tuple) and output:
            return cls._frames_from_value("frames", output[0])

        return []