from typing import List

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from logger_config import config_logger

logger = config_logger(__name__)


class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker.

    Takes (query, candidate) pairs and outputs a relevance score in [0, 1].
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Initialize the cross-encoder with a pretrained model.

        Args:
            model_name: Name of a HF cross-encoder model. Must be compatible
                with TFAutoModelForSequenceClassification.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross-encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256,
    ) -> List[float]:
        """
        Compute relevance scores for each candidate w.r.t. the query.

        Args:
            query: User's query text.
            candidates: List of candidate response texts.
            max_length: Max token length for each (query, candidate) pair.

        Returns:
            A list of float scores in [0, 1], one per candidate, indicating
            the model's predicted relevance.
        """
        # The tokenizer raises on an empty batch, so short-circuit early.
        if not candidates:
            return []

        # Build (query, candidate) pairs, then tokenize them as a batch.
        pair_texts = [(query, candidate) for candidate in candidates]
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
            verbose=False,
        )

        # Forward pass; logits have shape [batch_size, 1].
        # Note: token_type_ids are optional; .get() avoids a KeyError for
        # models whose tokenizers do not produce them.
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            token_type_ids=encodings.get("token_type_ids"),
        )
        logits = outputs.logits  # shape [batch_size, 1]

        # Convert logits to the [0, 1] range with a sigmoid.
        scores = tf.nn.sigmoid(logits)  # shape [batch_size, 1]

        # Flatten to a 1D NumPy array and ensure Python float type.
        scores = tf.reshape(scores, [-1])
        scores = scores.numpy().astype(float)

        return scores.tolist()
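
# Example usage (a minimal sketch): the query and candidate strings below are
# hypothetical, and running this downloads the model on first use.
if __name__ == "__main__":
    reranker = CrossEncoderReranker()

    query = "How do I reset my password?"
    candidates = [
        "Open Settings and choose 'Reset password' to get a reset link.",
        "Our office hours are Monday through Friday, 9am to 5pm.",
        "You can change your password from the account security page.",
    ]

    # Score each candidate against the query, then sort best-first.
    scores = reranker.rerank(query, candidates)
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    for candidate, score in ranked:
        print(f"{score:.4f}  {candidate}")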