File size: 2,498 Bytes
f7b283c
 
7a0020b
f7b283c
 
 
 
 
 
cc2577d
f7b283c
 
7a0020b
cc2577d
7a0020b
cc2577d
7a0020b
 
f7b283c
 
7a0020b
f7b283c
 
 
 
 
 
 
 
cc2577d
7a0020b
 
 
 
 
cc2577d
f7b283c
cc2577d
f7b283c
 
 
 
 
 
c7c1b4e
 
f7b283c
 
cc2577d
 
 
f7b283c
 
 
cc2577d
f7b283c
 
7a0020b
 
 
cc2577d
7a0020b
 
 
f7b283c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
from typing import List

from logger_config import config_logger
# Module-level logger, obtained from the project's config_logger helper using this module's name.
logger = config_logger(__name__)

class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker. Takes (query, candidate) pairs and outputs a relevance score [0...1].

    The tokenizer and the TensorFlow sequence-classification model are loaded once
    in the constructor and reused by every call to `rerank`.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2") -> None:
        """
        Init the cross-encoder with a pretrained model.
        Args:
            model_name: Name of a HF cross-encoder model. Must be compatible with TFAutoModelForSequenceClassification.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Compute relevance scores for each candidate w.r.t. query.
        Args:
            query: User's query text.
            candidates: List of candidate response texts. May be empty.
            max_length: Max token length for each (query, candidate) pair;
                longer pairs are truncated by the tokenizer.
        Returns:
            A list of float scores [0...1], in the same order as `candidates`.
            Returns an empty list when `candidates` is empty.
        """
        # Nothing to score: skip tokenization and the forward pass entirely,
        # since an empty batch is not a valid input for the tokenizer/model.
        if not candidates:
            return []

        # Build (query, candidate) pairs, then tokenize them as one padded batch
        pair_texts = [(query, candidate) for candidate in candidates]
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
            verbose=False
        )

        # Forward pass over the whole batch.
        # Note: token_type_ids are optional. .get() avoids KeyError and passes
        # None when the tokenizer does not produce them.
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            token_type_ids=encodings.get("token_type_ids")
        )

        # Convert logits to the [0...1] range with an element-wise sigmoid.
        # NOTE(review): assumes the model emits one logit per pair
        # (shape [batch_size, 1]); with more labels the flattened list below
        # would hold more than one value per candidate -- confirm for other models.
        logits = outputs.logits
        scores = tf.nn.sigmoid(logits)

        # Flatten to 1D NumPy array, ensure float type
        scores = tf.reshape(scores, [-1])
        scores = scores.numpy().astype(float)

        return scores.tolist()