import logging
import time
from datetime import datetime  # noqa: F401 - unused here; kept in case other code imports it from this module
from typing import Any, AnyStr, Dict, List  # noqa: F401

from sentence_transformers import CrossEncoder
import torch

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference-endpoint handler that scores passages against a query with a cross-encoder."""

    def __init__(self, path: str = ""):
        """Load the cross-encoder model.

        Args:
            path: Model identifier or local directory forwarded to ``CrossEncoder``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cross_encoder = CrossEncoder(path, device=device)
        # Sigmoid is stateless, so a single module instance can be reused for every request
        # instead of constructing a new one per call.
        self.activation_fct = torch.nn.Sigmoid()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[float]]:
        """Score each passage against the query.

        Args:
            data: Request payload of the form
                ``{"inputs": {"query": <str>, "passages": <list of str | str>}}``.
                A single string for ``passages`` is treated as one passage.

        Returns:
            A dictionary with a single key ``"scores"``: one sigmoid-activated score per
            passage, in the same order as ``passages``. An empty ``passages`` list yields
            ``{"scores": []}``.

        Raises:
            ValueError: If ``"inputs"`` is missing or not a dict, or if it lacks
                ``"query"`` or ``"passages"``.
        """
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError(
                'Request must contain an "inputs" object with "query" and "passages" keys.'
            )

        query = inputs.get("query")
        passages = inputs.get("passages")
        if query is None:
            raise ValueError('"inputs" must contain a "query".')
        if passages is None:
            raise ValueError('"inputs" must contain "passages".')
        # Iterating a bare string would score it one character at a time; treat it as one passage.
        if isinstance(passages, str):
            passages = [passages]

        logger.info("Query: %s", query)
        logger.info("N. of passages: %d", len(passages))

        # CrossEncoder.predict inspects the first element of its input, so an empty
        # batch would fail there; answer it directly instead.
        if not passages:
            return {"scores": []}

        # perf_counter is monotonic, so the measured duration is immune to wall-clock changes.
        start = time.perf_counter()
        scores = self.cross_encoder.predict(
            [(query, passage) for passage in passages],
            activation_fct=self.activation_fct,
        )
        elapsed = time.perf_counter() - start

        logger.info(
            "Time to run cross-encoder for query '%s' with %d passages: %.3fs",
            query,
            len(passages),
            elapsed,
        )
        logger.info("Scores: %s", scores)

        return {"scores": scores.tolist()}