| """ |
| Thin client for the encoder Hugging Face Space. |
| |
| Uses the official `gradio_client` so the URL contract isn't fragile across |
| Gradio versions. Reconstructs torch tensors from the base64 float16 payload. |
| |
| Env vars |
| -------- |
| ENCODER_SPACE_URL Public URL of the Space, e.g. |
| https://USER-pseudoscorex-encoder.hf.space OR USER/SPACE-NAME |
| Required. |
| HF_TOKEN Required only if the Space is private. |
| """ |
| import base64 |
| import logging |
| import os |
| import threading |
| import time |
|
|
| import numpy as np |
| import torch |
| from gradio_client import Client |
|
|
# Named logger; configure logging under "encoder_client" to control this
# module's output.
logger = logging.getLogger("encoder_client")

# Environment configuration, read once when the module is imported.
# The Space URL is whitespace-stripped; an empty string means "not configured".
ENCODER_SPACE_URL = os.environ.get("ENCODER_SPACE_URL", "").strip()
# Optional Hugging Face token; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Shared gradio Client, created lazily by _get_client(), and the lock that
# guards its one-time construction.
_client = None
_client_lock = threading.Lock()
|
|
|
|
def _get_client() -> Client:
    """Return the shared Gradio client, building it on first use.

    The module-level ``_client`` is checked once without the lock and again
    while holding ``_client_lock``. Concurrent first callers therefore
    construct at most one client.

    When ``HF_TOKEN`` is set, the token is passed as ``hf_token=`` first and
    ``token=`` second. Different ``gradio_client`` releases name this
    constructor parameter differently. Only a ``TypeError`` that reports that
    exact keyword as unexpected triggers the fallback. Any other ``TypeError``
    raised while the client is being constructed propagates, so it is not
    mistaken for a signature mismatch and silently downgraded to an
    unauthenticated connection.

    Returns:
        The cached ``gradio_client.Client`` connected to ``ENCODER_SPACE_URL``.

    Raises:
        RuntimeError: if ``ENCODER_SPACE_URL`` is empty.
    """
    global _client
    # Fast path: no locking once the client exists.
    if _client is not None:
        return _client
    with _client_lock:
        # Another thread may have finished construction while we waited.
        if _client is not None:
            return _client
        if not ENCODER_SPACE_URL:
            raise RuntimeError(
                "ENCODER_SPACE_URL is not set. Point it at your encoder Space, "
                "e.g. https://USER-pseudoscorex-encoder.hf.space"
            )
        logger.info("Connecting to encoder Space: %s", ENCODER_SPACE_URL)

        if not HF_TOKEN:
            _client = Client(ENCODER_SPACE_URL)
            return _client

        for kw in ("hf_token", "token"):
            try:
                _client = Client(ENCODER_SPACE_URL, **{kw: HF_TOKEN})
            except TypeError as err:
                message = str(err)
                # Only "unexpected keyword argument '<kw>'" means this client
                # version lacks the parameter; any other TypeError is a real
                # failure inside Client and must not be swallowed.
                if "unexpected keyword argument" not in message or f"'{kw}'" not in message:
                    raise
                continue
            return _client

        logger.warning(
            "gradio_client did not accept hf_token/token; "
            "connecting without auth (Space must be public)."
        )
        _client = Client(ENCODER_SPACE_URL)
        return _client
|
|
|
|
def encode_text(text: str, device: torch.device):
    """Encode ``text`` with the remote encoder Space and return torch tensors.

    Calls the Space's ``/encode`` endpoint through the shared client. It then
    rebuilds the hidden states from the base64-encoded float16 buffer in the
    response.

    Args:
        text: Input string, forwarded unchanged to the Space.
        device: Device the returned tensors are placed on.

    Returns:
        A tuple ``(hidden, attention_mask, clean_tokens)``:

        - ``hidden``: float32 tensor of shape ``(1, *out["shape"])``. It is
          documented elsewhere as ``(1, seq_len, 1024)``, which assumes the
          Space reports ``shape`` as ``[seq_len, 1024]``. NOTE(review):
          confirm against the Space.
        - ``attention_mask``: ``torch.long`` tensor built from the response's
          ``attention_mask`` with a leading batch dimension of 1.
        - ``clean_tokens``: the response's ``clean_tokens`` value, passed
          through as-is (presumably ``list[str]``).

    Raises:
        RuntimeError: if the response is not a dict, lacks any of
            ``hidden_b64`` / ``shape`` / ``attention_mask`` / ``clean_tokens``,
            or its hidden-state buffer cannot be decoded into the declared
            shape.
    """
    client = _get_client()

    t0 = time.perf_counter()
    out = client.predict(text, api_name="/encode")
    logger.debug("Space round-trip: %.2fs", time.perf_counter() - t0)

    # Validate every key read below up front, so a malformed response fails
    # with one consistent error instead of a bare KeyError halfway through.
    required = ("hidden_b64", "shape", "attention_mask", "clean_tokens")
    if not isinstance(out, dict) or any(key not in out for key in required):
        raise RuntimeError(f"Unexpected encoder Space response: {out!r}")

    try:
        raw = base64.b64decode(out["hidden_b64"])
        arr = np.frombuffer(raw, dtype=np.float16).reshape(out["shape"])
    except ValueError as err:  # includes binascii.Error from bad base64
        raise RuntimeError(
            f"Encoder Space returned a hidden-state payload that does not "
            f"match its declared shape {out['shape']!r}"
        ) from err

    # astype() yields a fresh, writable float32 copy of the read-only
    # frombuffer view before it is handed to torch.
    hidden = torch.from_numpy(arr.astype(np.float32)).unsqueeze(0).to(device)
    mask = torch.tensor(out["attention_mask"], dtype=torch.long, device=device).unsqueeze(0)
    return hidden, mask, out["clean_tokens"]
|
|