matter/transformers_runtime.py
"""HuggingFace transformers runtime β€” implements matter.engine.Runtime.
Loads Gemma 4 lazily on first inference (so cold Spaces serve the demo-mode path
without ever paying the load cost) and wraps inference in @spaces.GPU so the
Space's ZeroGPU pool only spins up while we're actually generating.
Picks Gemma 4 E2B (5B, any-to-any, instruction-tuned) by default. Override via
the MATTER_MODEL_ID Space secret.
"""
from __future__ import annotations
import os
import threading
from pathlib import Path
from typing import Literal
import torch
from PIL import Image
try:
import spaces # type: ignore
HAS_SPACES = True
except ImportError:
HAS_SPACES = False
DEFAULT_MODEL_ID = os.environ.get("MATTER_MODEL_ID", "google/gemma-3n-E2B-it")
DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MATTER_MAX_NEW_TOKENS", "1024"))
DEFAULT_LORA_ID = os.environ.get("MATTER_LORA_ID", "").strip() or None
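# Example overrides (illustrative values only, set as Space secrets or env vars):
#   MATTER_MODEL_ID=google/gemma-3n-E4B-it
#   MATTER_MAX_NEW_TOKENS=512
#   MATTER_LORA_ID=<your-adapter-repo>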
# Module-level init lock (must NOT be an instance attribute: `self` gets
# pickled across the ZeroGPU process boundary, and threading.Lock objects
# can't be pickled). Modules are imported per process, so this lock is
# per-process, which is exactly the granularity we want.
_LOAD_LOCK = threading.Lock()
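# (For reference: pickle.dumps(threading.Lock()) raises
# "TypeError: cannot pickle '_thread.lock' object".)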
def _gpu_decorator(fn):
"""No-op when running locally (no `spaces` module), real decorator on HF."""
if HAS_SPACES:
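        # duration: expected maximum runtime of each decorated call, in seconds.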
return spaces.GPU(duration=90)(fn)
return fn
class TransformersRuntime:
"""Implements matter.engine.Runtime over HF transformers + Gemma 4."""
    # The Passport schema's provenance.runtime enum doesn't include "transformers",
    # so report as "other" and surface the actual stack via model_id.
name: Literal["other"] = "other"
def __init__(
self,
model: str = DEFAULT_MODEL_ID,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
lora_id: str | None = DEFAULT_LORA_ID,
):
self.model_id = model
self.lora_id = lora_id
self.max_new_tokens = max_new_tokens
self._model = None
self._processor = None
def _ensure_loaded(self) -> None:
# Fast path: already loaded, no lock needed.
if self._model is not None:
return
# Module-level lock guards against concurrent first-call races. Two
# users hitting a cold Space simultaneously could both enter
# from_pretrained without this lock and double-allocate, OOM'ing CUDA.
with _LOAD_LOCK:
# Double-checked locking: another thread may have completed the
# load while we were waiting for the lock.
if self._model is not None:
return
from transformers import AutoModelForImageTextToText, AutoProcessor
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(self.model_id)
model = AutoModelForImageTextToText.from_pretrained(
self.model_id,
torch_dtype=dtype,
device_map=device,
)
if self.lora_id:
try:
from peft import PeftModel
model = PeftModel.from_pretrained(model, self.lora_id)
except Exception as e:
print(f"[TransformersRuntime] LoRA load failed ({self.lora_id}): {e}")
model.eval()
# Publish atomically β€” readers without the lock should never see a
# half-initialized state.
self._processor = processor
self._model = model
def infer(self, prompt: str, image: Path | None) -> str:
return self._infer_gpu(prompt, str(image) if image is not None else None)
@_gpu_decorator
def _infer_gpu(self, prompt: str, image_path: str | None) -> str:
self._ensure_loaded()
proc = self._processor
model = self._model
        # Image first, then text, matching the official google/gemma-3n-E2B-it usage.
content: list[dict] = []
if image_path:
content.append({"type": "image", "image": Image.open(image_path).convert("RGB")})
content.append({"type": "text", "text": prompt})
messages = [{"role": "user", "content": content}]
inputs = proc.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
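        # Track the prompt length so only the newly generated tokens are decoded below.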
input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
)
        # Decode with special tokens kept so the processor can parse them out
        # cleanly via parse_response() when it exposes that method.
raw = proc.decode(outputs[0][input_len:], skip_special_tokens=False)
if hasattr(proc, "parse_response"):
parsed = proc.parse_response(raw)
if isinstance(parsed, str):
return parsed
if isinstance(parsed, dict) and "content" in parsed:
return parsed["content"] if isinstance(parsed["content"], str) else str(parsed["content"])
return str(parsed)
return proc.decode(outputs[0][input_len:], skip_special_tokens=True)
__all__ = ["TransformersRuntime"]