Spaces:
Sleeping
Sleeping
| """ | |
| eduai-embedder — tiny embedding microservice. | |
| One process, one model, three routes. Deployed as a Docker Space on | |
| HuggingFace and called by `eduai_platform` (and any other EduAI service | |
| that needs embeddings) so individual developers don't have to install | |
| torch + sentence-transformers locally. | |
| Endpoints | |
| --------- | |
| GET /health → {status, model, dim} | |
| POST /embed → {embeddings: [[float]], model, dim} | |
| POST /embed_one → {embedding: [float], model, dim} | |
| Authentication | |
| -------------- | |
| If the `EMBEDDER_API_KEY` env var is set, all routes except /health | |
| require an `X-API-Key` header that matches it. Leave it unset only for | |
| local dev (the default in `.env.example` makes you set one). | |
| Configuration (env vars) | |
| ------------------------ | |
| EMBEDDER_MODEL_NAME sentence-transformers model id (default: all-MiniLM-L6-v2) | |
| EMBEDDER_API_KEY shared secret; if set, required on /embed* routes | |
| EMBEDDER_MAX_BATCH reject batches larger than this (default: 128) | |
| EMBEDDER_MAX_TEXT_LEN reject texts longer than this many characters (default: 8000) | |
| EMBEDDER_CORS comma-separated allow-origins (default: *) | |
| """ | |
import hmac
import logging
import os
from typing import List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# ----------------------------------------------------------------------------- config
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
log = logging.getLogger("eduai-embedder")

# All knobs come from the environment so the same Docker image works everywhere.
MODEL_NAME = os.getenv("EMBEDDER_MODEL_NAME", "all-MiniLM-L6-v2")
API_KEY = os.getenv("EMBEDDER_API_KEY", "")  # empty string => open mode (local dev)
MAX_BATCH = int(os.getenv("EMBEDDER_MAX_BATCH", "128"))
MAX_TEXT_LEN = int(os.getenv("EMBEDDER_MAX_TEXT_LEN", "8000"))
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("EMBEDDER_CORS", "*").split(",")
    if origin.strip()
]

# ----------------------------------------------------------------------------- model
# Loaded once at import time; HF Spaces keeps the process (and the model) warm.
log.info("Loading sentence-transformers model: %s ...", MODEL_NAME)
_model = SentenceTransformer(MODEL_NAME)
DIM = _model.get_sentence_embedding_dimension()
log.info("Model loaded (dim=%d, normalize_embeddings=True)", DIM)
# ----------------------------------------------------------------------------- app
app = FastAPI(
    title="eduai-embedder",
    description="Tiny embedding microservice for the EduAI platform.",
    version="0.1.0",
)

# Browser-based callers need CORS; the allow-list comes from EMBEDDER_CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ----------------------------------------------------------------------------- schemas
class EmbedBatchIn(BaseModel):
    """Request body for POST /embed: a non-empty list of texts."""

    texts: List[str] = Field(..., min_length=1, description="Texts to embed.")
class EmbedOneIn(BaseModel):
    """Request body for POST /embed_one: a single non-empty text."""

    text: str = Field(..., min_length=1)
class EmbedOut(BaseModel):
    """Response for POST /embed: one vector per input text, plus model metadata."""

    embeddings: List[List[float]]
    model: str
    dim: int
class EmbedOneOut(BaseModel):
    """Response for POST /embed_one: a single vector plus model metadata."""

    embedding: List[float]
    model: str
    dim: int
class HealthOut(BaseModel):
    """Response for GET /health: service status plus model metadata."""

    status: str
    model: str
    dim: int
# ----------------------------------------------------------------------------- auth
def require_api_key(x_api_key: Optional[str] = Header(default=None, alias="X-API-Key")) -> None:
    """FastAPI dependency: reject requests if EMBEDDER_API_KEY is set and the header doesn't match.

    When the env var is unset/empty the service runs in open mode (intended
    for local development only) and every request is allowed.

    Raises:
        HTTPException: 401 when a key is configured and the X-API-Key header
            is missing or wrong.
    """
    if not API_KEY:
        return  # open mode (intended for local dev only)
    # hmac.compare_digest is constant-time, so comparing the untrusted header
    # against the shared secret can't leak key prefixes via response timing.
    if x_api_key is None or not hmac.compare_digest(x_api_key, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")
# ----------------------------------------------------------------------------- routes
@app.get("/health", response_model=HealthOut)
def health() -> HealthOut:
    """Liveness probe. Always public; HF Spaces' built-in checks rely on this.

    Note: no auth dependency here on purpose — the module docstring exempts
    /health from the API-key requirement.
    """
    # Bug fix: the function was never registered as a route (no @app.get),
    # so GET /health returned 404 despite the documented contract.
    return HealthOut(status="ok", model=MODEL_NAME, dim=DIM)
@app.post("/embed", response_model=EmbedOut, dependencies=[Depends(require_api_key)])
def embed_batch(body: EmbedBatchIn) -> EmbedOut:
    """Embed a batch of texts. Vectors are L2-normalized for cosine similarity.

    Args:
        body: Non-empty list of texts (validated by EmbedBatchIn).

    Raises:
        HTTPException: 400 when the batch exceeds EMBEDDER_MAX_BATCH, or any
            single text exceeds EMBEDDER_MAX_TEXT_LEN characters.
    """
    # Bug fix: the route was never registered (no @app.post) and the imported
    # `Depends`/`require_api_key` auth was never wired up, so the documented
    # X-API-Key requirement on /embed was not enforced.
    if len(body.texts) > MAX_BATCH:
        raise HTTPException(status_code=400, detail=f"Batch too large (max {MAX_BATCH}).")
    # Validate lengths up front so we fail fast before touching the model.
    for i, text in enumerate(body.texts):
        if len(text) > MAX_TEXT_LEN:
            raise HTTPException(
                status_code=400,
                detail=f"Text at index {i} too long (max {MAX_TEXT_LEN} characters).",
            )
    vectors = _model.encode(
        body.texts,
        normalize_embeddings=True,  # unit vectors => dot product == cosine similarity
        batch_size=64,
    ).tolist()
    return EmbedOut(embeddings=vectors, model=MODEL_NAME, dim=DIM)
@app.post("/embed_one", response_model=EmbedOneOut, dependencies=[Depends(require_api_key)])
def embed_one(body: EmbedOneIn) -> EmbedOneOut:
    """Embed a single text — convenience for chat query embeddings.

    Args:
        body: A single non-empty text (validated by EmbedOneIn).

    Raises:
        HTTPException: 400 when the text exceeds EMBEDDER_MAX_TEXT_LEN characters.
    """
    # Bug fix: the route was never registered (no @app.post) and the API-key
    # dependency was never applied, so POST /embed_one was unreachable.
    if len(body.text) > MAX_TEXT_LEN:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long (max {MAX_TEXT_LEN} characters).",
        )
    vector = _model.encode(body.text, normalize_embeddings=True).tolist()
    return EmbedOneOut(embedding=vector, model=MODEL_NAME, dim=DIM)