File size: 7,539 Bytes
d6585f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import numpy as np
from typing import List, Dict
from pyserini.search.faiss import PRFDenseSearchResult, AnceQueryEncoder
from pyserini.search.lucene import LuceneSearcher
import json


class DenseVectorPrf:
    """Base class for dense-vector pseudo-relevance-feedback (PRF) strategies.

    Both hooks defined here are placeholders: they accept arbitrary keyword
    arguments, perform no work and return ``None``.
    """

    def __init__(self):
        """Create the base object; there is no state to initialise."""
        return None

    def get_prf_q_emb(self, **kwargs):
        """Single-query PRF hook.

        Parameters
        ----------
        **kwargs
            Ignored by this base implementation.

        Returns
        -------
        None
        """
        return None

    def get_batch_prf_q_emb(self, **kwargs):
        """Batch PRF hook.

        Parameters
        ----------
        **kwargs
            Ignored by this base implementation.

        Returns
        -------
        None
        """
        return None


class DenseVectorAveragePrf(DenseVectorPrf):
    """PRF strategy that averages the query embedding with the feedback document embeddings."""

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Average PRF with Dense Vectors

        The query embedding and the ``vectors`` attribute of every candidate
        are stacked row-wise and averaged over the rows.

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings.
            May be empty, in which case the query embedding is returned
            unchanged (the average of the query alone).

        Returns
        -------
        np.ndarray
            return new query embeddings
        """
        candidate_embs = [item.vectors for item in prf_candidates]
        if not candidate_embs:
            # np.vstack cannot stack the query against an empty list (it would
            # raise a shape-mismatch ValueError). With no feedback documents
            # the average is taken over the query alone.
            return np.mean(np.atleast_2d(emb_qs), axis=0)
        return np.mean(np.vstack((emb_qs, candidate_embs)), axis=0)

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Average PRF with Dense Vectors for a batch of queries

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings, indexed in the same order as ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain the document embeddings.

        Returns
        -------
        np.ndarray
            return new query embeddings, one row per topic id, cast to float32
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')


class DenseVectorRocchioPrf(DenseVectorPrf):
    """PRF strategy that applies the Rocchio update to dense query embeddings."""

    def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int):
        """
        Parameters
        ----------
        alpha : float
            Rocchio parameter, controls the weight assigned to the original query embedding.
        beta : float
            Rocchio parameter, controls the weight assigned to the positive document embeddings.
        gamma : float
            Rocchio parameter, controls the weight assigned to the negative document embeddings.
        topk : int
            Rocchio parameter, set topk documents as positive document feedbacks.
        bottomk : int
            Rocchio parameter, set bottomk documents as negative document feedbacks.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.topk = topk
        self.bottomk = bottomk

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Rocchio PRF with Dense Vectors

        Computes ``alpha * query + beta * mean(first topk candidates)`` and,
        when ``bottomk > 0``, subtracts ``gamma * mean(last bottomk candidates)``.
        A feedback set that turns out to be empty contributes nothing, so the
        result never becomes NaN from averaging an empty slice.

        Parameters
        ----------
        emb_qs : np.ndarray
            query embedding
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings.

        Returns
        -------
        np.ndarray
            return new query embeddings
        """
        candidate_embs = [item.vectors for item in prf_candidates]

        # alpha * emb_qs allocates a fresh array, so the caller's embedding is
        # never modified by the in-place subtraction below.
        new_emb_q = self.alpha * emb_qs

        # NOTE(review): if fewer than topk + bottomk candidates are supplied the
        # positive and negative slices overlap -- confirm callers retrieve enough.
        positive_embs = candidate_embs[:self.topk]
        if positive_embs:
            # np.mean over an empty list yields NaN (with a RuntimeWarning),
            # which would poison the whole embedding; skip the term instead.
            new_emb_q = new_emb_q + self.beta * np.mean(positive_embs, axis=0)

        if self.bottomk > 0:
            negative_embs = candidate_embs[-self.bottomk:]
            if negative_embs:
                new_emb_q -= self.gamma * np.mean(negative_embs, axis=0)
        return new_emb_q

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Rocchio PRF with Dense Vectors for a batch of queries

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings, indexed in the same order as ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain the document embeddings.

        Returns
        -------
        np.ndarray
            return new query embeddings, one row per topic id, cast to float32
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')


class DenseVectorAncePrf(DenseVectorPrf):
    """PRF strategy that re-encodes the query together with the feedback passages' text."""

    def __init__(self, encoder: AnceQueryEncoder, sparse_searcher: LuceneSearcher):
        """
        Parameters
        ----------
        encoder : AnceQueryEncoder
            The new ANCE query encoder for ANCE-PRF.
        sparse_searcher : LuceneSearcher
            The sparse searcher using lucene index, for retrieving doc contents.
        """
        super().__init__()
        self.encoder = encoder
        self.sparse_searcher = sparse_searcher

    def _build_prf_text(self, query: str, prf_candidates: List[PRFDenseSearchResult]) -> str:
        """Assemble the encoder input for one query and its feedback passages.

        Each candidate's document is looked up by ``docid`` in the sparse
        searcher, its raw form is parsed as JSON and the ``'contents'`` field
        is taken as the passage text. The result is
        ``cls_token + sep_token.join([query, passage_1, ...]) + sep_token``,
        using the special tokens of the encoder's tokenizer.
        """
        passage_texts = [query]
        for item in prf_candidates:
            raw_doc = json.loads(self.sparse_searcher.doc(item.docid).raw())
            passage_texts.append(raw_doc['contents'])
        tokenizer = self.encoder.tokenizer
        sep_token = tokenizer.sep_token
        return f'{tokenizer.cls_token}{sep_token.join(passage_texts)}{sep_token}'

    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform single ANCE-PRF with Dense Vectors

        Parameters
        ----------
        query : str
            query text
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult; their ``docid`` is used to fetch the
            passage text from the sparse index.

        Returns
        -------
        np.ndarray
            return new query embeddings, reshaped to ``(1, len(embedding))``
        """
        full_text = self._build_prf_text(query, prf_candidates)
        emb_q = self.encoder.prf_encode(full_text)
        # Present the single embedding as a one-row matrix.
        return emb_q.reshape((1, len(emb_q)))

    def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]]) -> np.ndarray:
        """Perform batch ANCE-PRF with Dense Vectors

        Parameters
        ----------
        topics : List[str]
            List of query texts.
        topic_ids: List[str]
            List of topic ids, in the same order as ``topics``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult; their
            ``docid`` is used to fetch the passage text from the sparse index.

        Returns
        -------
        np.ndarray
            return new query embeddings, as produced by the encoder's
            ``prf_batch_encode``
        """
        prf_passage_texts = [
            self._build_prf_text(query, prf_candidates[topic_ids[index]])
            for index, query in enumerate(topics)
        ]
        return self.encoder.prf_batch_encode(prf_passage_texts)