| """ |
| Custom handler for BGE dense retrieval on HuggingFace Inference Endpoints. |
| |
| Returns dense embeddings for queries and passages. |
| |
| Key difference from ANCE: BGE requires an instruction prefix on queries |
| for retrieval tasks. Passages are encoded without any prefix. |
| """ |
|
|
import math
from typing import Any, Dict, List, Union

import torch
from transformers import AutoModel, AutoTokenizer
|
|
|
|
| |
# BGE retrieval-task instruction: prepended to queries only; passages are
# encoded without any prefix (see module docstring).
QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "
|
|
|
|
class EndpointHandler:
    """Handler for BGE dense-embedding generation.

    Encodes texts into L2-normalized [CLS] embeddings. Queries get the BGE
    retrieval instruction prefix (``QUERY_INSTRUCTION``); passages do not.
    """

    def __init__(self, path: str = ""):
        """Load the model and tokenizer from ``path`` and pick a device.

        Args:
            path: Local path or hub identifier of the BGE checkpoint.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        self.model.eval()  # inference only: disable dropout / training behavior

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        print(f"BGE loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Accepts:
            - {"inputs": "text"}                    single text (passage, no prefix)
            - {"inputs": ["t1", "t2", ...]}         batch of passages
            - {"inputs": ..., "is_query": true}     same, with query prefix added
            - {"query": "...", "passages": [...]}   query + passages (scoring mode)

        Returns:
            {"embeddings": [...]} for embedding requests,
            {"scores": [...]} for scoring requests, or
            {"error": "..."} when no usable input is supplied.
        """
        inputs = data.get("inputs", None)
        is_query = data.get("is_query", False)
        query = data.get("query", None)
        passages = data.get("passages", None)

        # Scoring mode: embed the instruction-prefixed query and the raw
        # passages, then score each passage by cosine similarity.
        if query is not None and passages is not None:
            query_emb = self._encode([f"{QUERY_INSTRUCTION}{query}"])[0]
            passage_embs = self._encode(passages)
            scores = [self._cosine_similarity(query_emb, p_emb) for p_emb in passage_embs]
            return {"scores": scores}

        if inputs is None:
            return {"error": "No inputs provided. Use 'inputs' or 'query'+'passages'."}

        if isinstance(inputs, str):
            inputs = [inputs]

        # BGE requires the instruction prefix on queries only.
        if is_query:
            inputs = [f"{QUERY_INSTRUCTION}{text}" for text in inputs]

        return {"embeddings": self._encode(inputs)}

    def _encode(self, texts: List[str], max_length: int = 512) -> List[List[float]]:
        """Encode ``texts`` into L2-normalized embeddings.

        Args:
            texts: Batch of texts (any instruction prefix already applied).
            max_length: Truncation length in tokens.

        Returns:
            One embedding (list of floats) per input text. An empty input
            batch short-circuits to [] — HF tokenizers reject empty lists.
        """
        if not texts:
            return []

        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**encoded)

        # BGE uses the [CLS] token (position 0) as the sentence embedding.
        embeddings = outputs.last_hidden_state[:, 0, :]
        # Unit-normalize so dot product == cosine similarity downstream.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().tolist()

    def _cosine_similarity(self, emb1: List[float], emb2: List[float]) -> float:
        """Return the cosine similarity of two vectors (0.0 on zero norm)."""
        dot = sum(a * b for a, b in zip(emb1, emb2))
        norm1 = math.sqrt(sum(a * a for a in emb1))
        norm2 = math.sqrt(sum(b * b for b in emb2))
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot / (norm1 * norm2)
|
|