sabaimran
Add handler and requirements.txt to setup our custom inference endpoint handler
0d15560
raw
history blame
1.55 kB
import logging
from datetime import datetime
from typing import Any, AnyStr, Dict, List

import torch
from sentence_transformers import CrossEncoder
# Module-level logger, named after this module; used by EndpointHandler for request logging.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference endpoint handler that scores passages against a query with a cross-encoder."""

    def __init__(self, path=""):
        """Load the cross-encoder model.

        Args:
            path: Model location passed straight through to ``CrossEncoder``.
        """
        # Run on the GPU when torch can see one, otherwise fall back to the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cross_encoder = CrossEncoder(path, device=device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[float]]:
        """Score every passage against the query.

        Args:
            data: Request payload. Its "inputs" entry must be a mapping holding
                a "query" and a list of "passages".

        Returns:
            A dictionary with a single key "scores" holding one float per
            passage, in the same order as the passages. The raw model outputs
            are passed through a sigmoid. An empty passage list yields an
            empty score list without invoking the model.

        Raises:
            ValueError: If "inputs" is missing or not a mapping, if "query" is
                missing, or if "passages" is not a list/tuple.
        """
        # Validate up front so that malformed requests fail with a clear message
        # instead of an AttributeError/TypeError from deep inside the call.
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError(
                'Request payload must contain an "inputs" object with "query" and "passages".'
            )

        query = inputs.get("query")
        if query is None:
            raise ValueError('"inputs" must contain a "query".')

        passages = inputs.get("passages")
        # A bare string would otherwise be iterated character by character.
        if not isinstance(passages, (list, tuple)):
            raise ValueError('"inputs" must contain "passages" as a list.')

        # Nothing to rank: skip the model call entirely.
        if not passages:
            return {"scores": []}

        logger.info("Query: %s", query)
        logger.info("N. of passages: %s", len(passages))

        pairs = [(query, passage) for passage in passages]
        start_time = datetime.now()
        raw_scores = self.cross_encoder.predict(pairs, activation_fct=torch.nn.Sigmoid())
        logger.info(
            "Time to run cross-encoder for query '%s' with %s passages: %s",
            query,
            len(passages),
            datetime.now() - start_time,
        )
        logger.info("Scores: %s", raw_scores)

        # Convert to plain Python floats so the result matches the declared
        # List[float] return type regardless of the array type predict() yields.
        return {"scores": [float(score) for score in raw_scores]}