# NOTE(review): extraction artifact removed here — the original capture included
# web-page chrome (status text, file size, commit hash, line-number gutter)
# that is not part of the Python source.
import numpy as np
from typing import List, Dict
from pyserini.search.faiss import PRFDenseSearchResult, AnceQueryEncoder
from pyserini.search.lucene import LuceneSearcher
import json
class DenseVectorPrf:
    """Abstract base for dense-vector pseudo-relevance-feedback (PRF) strategies.

    Subclasses override :meth:`get_prf_q_emb` (single query) and
    :meth:`get_batch_prf_q_emb` (batch of queries) to produce new query
    embeddings from feedback documents. The base implementations are no-ops.
    """

    def __init__(self):
        pass

    def get_prf_q_emb(self, **kwargs):
        """Compute a new embedding for one query; no-op in the base class."""
        pass

    def get_batch_prf_q_emb(self, **kwargs):
        """Compute new embeddings for a batch of queries; no-op in the base class."""
        pass
class DenseVectorAveragePrf(DenseVectorPrf):
    """Average PRF: the new query embedding is the element-wise mean of the
    original query embedding and the feedback document embeddings."""

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Average PRF with Dense Vectors.

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding.
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings.

        Returns
        -------
        np.ndarray
            New query embedding: mean over the query vector and all
            candidate document vectors.
        """
        all_candidate_embs = [item.vectors for item in prf_candidates]
        # Stack the query row on top of the doc rows, then average over rows.
        return np.mean(np.vstack((emb_qs, all_candidate_embs)), axis=0)

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Average PRF with Dense Vectors for a batch of queries.

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings, one row per topic, in the same order as
            ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Per-topic feedback documents, keyed by topic id.

        Returns
        -------
        np.ndarray
            New query embeddings as a float32 matrix, one row per topic.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorRocchioPrf(DenseVectorPrf):
    """Rocchio PRF: new query = alpha * query + beta * mean(top-k docs)
    - gamma * mean(bottom-k docs)."""

    def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int):
        """
        Parameters
        ----------
        alpha : float
            Rocchio parameter, controls the weight assigned to the original query embedding.
        beta : float
            Rocchio parameter, controls the weight assigned to the positive document embeddings.
        gamma : float
            Rocchio parameter, controls the weight assigned to the negative document embeddings.
        topk : int
            Rocchio parameter, set topk documents as positive document feedbacks.
        bottomk : int
            Rocchio parameter, set bottomk documents as negative document feedbacks.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.topk = topk
        self.bottomk = bottomk

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Rocchio PRF with Dense Vectors.

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding.
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings,
            assumed ordered best-first (top of the ranking = positive feedback).

        Returns
        -------
        np.ndarray
            New query embedding.
        """
        all_candidate_embs = [item.vectors for item in prf_candidates]
        weighted_query_embs = self.alpha * emb_qs
        # Positive feedback: mean of the top-k ranked documents.
        weighted_mean_pos_doc_embs = self.beta * np.mean(all_candidate_embs[:self.topk], axis=0)
        new_emb_q = weighted_query_embs + weighted_mean_pos_doc_embs
        if self.bottomk > 0:
            # Negative feedback: mean of the bottom-k ranked documents, subtracted.
            weighted_mean_neg_doc_embs = self.gamma * np.mean(all_candidate_embs[-self.bottomk:], axis=0)
            new_emb_q -= weighted_mean_neg_doc_embs
        return new_emb_q

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Rocchio PRF with Dense Vectors for a batch of queries.

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings, one row per topic, in the same order as
            ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Per-topic feedback documents, keyed by topic id.

        Returns
        -------
        np.ndarray
            New query embeddings as a float32 matrix, one row per topic.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorAncePrf(DenseVectorPrf):
    """ANCE-PRF: re-encode the query together with the text of its feedback
    documents using a dedicated PRF query encoder."""

    def __init__(self, encoder: AnceQueryEncoder, sparse_searcher: LuceneSearcher):
        """
        Parameters
        ----------
        encoder : AnceQueryEncoder
            The new ANCE query encoder for ANCE-PRF.
        sparse_searcher : LuceneSearcher
            The sparse searcher using lucene index, for retrieving doc contents.
        """
        super().__init__()
        self.encoder = encoder
        self.sparse_searcher = sparse_searcher

    def _build_prf_text(self, query: str, prf_candidates: List[PRFDenseSearchResult]) -> str:
        """Build the encoder input ``[CLS] query [SEP] doc1 [SEP] ... [SEP]``
        from the query text and the raw contents of the feedback documents.

        The feedback candidates carry only docids/embeddings, so each
        document's text is fetched from the Lucene index and parsed from its
        raw JSON (``contents`` field).
        """
        passage_texts = [query]
        for item in prf_candidates:
            raw_text = json.loads(self.sparse_searcher.doc(item.docid).raw())
            passage_texts.append(raw_text['contents'])
        tokenizer = self.encoder.tokenizer
        return f'{tokenizer.cls_token}{tokenizer.sep_token.join(passage_texts)}{tokenizer.sep_token}'

    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform single ANCE-PRF with Dense Vectors.

        Parameters
        ----------
        query : str
            Query text.
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings.

        Returns
        -------
        np.ndarray
            New query embedding, shaped ``(1, dim)``.
        """
        full_text = self._build_prf_text(query, prf_candidates)
        emb_q = self.encoder.prf_encode(full_text)
        # Reshape the flat vector to a single-row matrix for downstream search.
        emb_q = emb_q.reshape((1, len(emb_q)))
        return emb_q

    def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]]) -> np.ndarray:
        """Perform batch ANCE-PRF with Dense Vectors.

        Parameters
        ----------
        topics : List[str]
            List of query texts.
        topic_ids: List[str]
            List of topic ids, aligned with ``topics``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Per-topic feedback documents, keyed by topic id.

        Returns
        -------
        np.ndarray
            New query embeddings, one row per topic.
        """
        prf_passage_texts = [
            self._build_prf_text(query, prf_candidates[topic_ids[index]])
            for index, query in enumerate(topics)
        ]
        emb_q = self.encoder.prf_batch_encode(prf_passage_texts)
        return emb_q