pseudoscorex-backend / model /encoder_client.py
Nanny7's picture
Initial commit: Flask backend for pseudoscore-x
219ee1a
"""
Thin client for the encoder Hugging Face Space.

Uses the official `gradio_client` so the URL contract isn't fragile across
Gradio versions. Reconstructs torch tensors from the base64 float16 payload.

Env vars
--------
ENCODER_SPACE_URL   Required. Public URL of the Space, e.g.
                    https://USER-pseudoscorex-encoder.hf.space,
                    or the Space id in USER/SPACE-NAME form.
HF_TOKEN            Required only if the Space is private.
"""
import base64
import logging
import os
import threading
import time
import numpy as np
import torch
from gradio_client import Client
# Module logger; the name is a fixed string rather than __name__.
logger = logging.getLogger("encoder_client")

# Connection settings are read from the environment once, at import time.
# A missing ENCODER_SPACE_URL becomes "" (rejected later, on first use);
# a missing HF_TOKEN stays None.
ENCODER_SPACE_URL = os.environ.get("ENCODER_SPACE_URL", "").strip()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Lock serialising creation of the shared client, and the client itself
# (None until the first successful connection).
_client_lock = threading.Lock()
_client = None
def _open_space_connection() -> Client:
    """Build a new Client for ENCODER_SPACE_URL, authenticating if possible."""
    if not HF_TOKEN:
        return Client(ENCODER_SPACE_URL)

    # gradio_client renamed/removed the auth kwarg between versions, so try
    # each known spelling; a TypeError means this version rejects that name.
    for auth_kwarg in ("hf_token", "token"):
        try:
            return Client(ENCODER_SPACE_URL, **{auth_kwarg: HF_TOKEN})
        except TypeError:
            pass

    # Neither spelling was accepted. Public Spaces don't need a token, so
    # fall back to an unauthenticated connection.
    logger.warning(
        "gradio_client did not accept hf_token/token; "
        "connecting without auth (Space must be public)."
    )
    return Client(ENCODER_SPACE_URL)


def _get_client() -> Client:
    """
    Return the shared Gradio client, creating it on first use.

    The module-level client is checked once without the lock and again after
    acquiring it, so concurrent first calls build only one client.

    Raises
    ------
    RuntimeError    If ENCODER_SPACE_URL is empty.
    """
    global _client
    if _client is None:
        with _client_lock:
            # Another thread may have connected while we waited for the lock.
            if _client is None:
                if not ENCODER_SPACE_URL:
                    raise RuntimeError(
                        "ENCODER_SPACE_URL is not set. Point it at your encoder Space, "
                        "e.g. https://USER-pseudoscorex-encoder.hf.space"
                    )
                logger.info("Connecting to encoder Space: %s", ENCODER_SPACE_URL)
                _client = _open_space_connection()
    return _client
def encode_text(text: str, device: torch.device):
    """
    Encode ``text`` on the remote encoder Space and rebuild tensors locally.

    Sends ``text`` to the Space's ``/encode`` endpoint and expects a dict with
    the fields ``hidden_b64`` (base64 of raw float16 bytes), ``shape``,
    ``attention_mask`` and ``clean_tokens``.

    Returns (hidden, attention_mask, clean_tokens):
        hidden          torch.float32 tensor of shape (1, seq_len, 1024)
        attention_mask  torch.long tensor of shape (1, seq_len)
        clean_tokens    list[str]
    Both tensors are moved to ``device``. The (seq_len, 1024) layout is the
    Space's contract; this function only reshapes to whatever ``shape`` says
    and prepends a batch dimension of 1.

    Raises
    ------
    RuntimeError    If the client cannot be configured, or if the response is
                    not a dict carrying all four expected fields.
    """
    client = _get_client()

    t0 = time.perf_counter()
    out = client.predict(text, api_name="/encode")
    logger.debug("Space round-trip: %.2fs", time.perf_counter() - t0)

    # Validate every field used below, so a partially malformed response is
    # reported as a bad response rather than as a bare KeyError further down.
    required_fields = ("hidden_b64", "shape", "attention_mask", "clean_tokens")
    if not isinstance(out, dict) or any(key not in out for key in required_fields):
        raise RuntimeError(f"Unexpected encoder Space response: {out!r}")

    # The payload is a flat float16 buffer; restore its shape, then upcast.
    # astype() also yields a writable copy (frombuffer's view is read-only).
    arr = np.frombuffer(base64.b64decode(out["hidden_b64"]), dtype=np.float16)
    arr = arr.reshape(out["shape"])  # (seq_len, 1024)
    hidden = torch.from_numpy(arr.astype(np.float32)).unsqueeze(0).to(device)

    mask = torch.tensor(out["attention_mask"], dtype=torch.long).unsqueeze(0).to(device)
    return hidden, mask, out["clean_tokens"]