Spaces:
Runtime error
Runtime error
import numpy as np | |
from typing import List, Dict | |
from pyserini.search.faiss import PRFDenseSearchResult, AnceQueryEncoder | |
from pyserini.search.lucene import LuceneSearcher | |
import json | |
class DenseVectorPrf:
    """Common base for dense-vector pseudo-relevance feedback (PRF) strategies.

    The base class holds no state. Both methods are placeholders that accept
    arbitrary keyword arguments, do nothing and return ``None``; the concrete
    strategies in this module override them.
    """

    def __init__(self):
        pass

    def get_prf_q_emb(self, **kwargs):
        """Placeholder for the single-query PRF computation; returns ``None``."""
        pass

    def get_batch_prf_q_emb(self, **kwargs):
        """Placeholder for the batched PRF computation; returns ``None``."""
        pass
class DenseVectorAveragePrf(DenseVectorPrf):
    """Average PRF: the new query is the mean of the query and feedback vectors."""

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Average PRF with Dense Vectors

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings
            (read from each item's ``vectors`` attribute). ``None`` or an
            empty list means "no feedback".

        Returns
        -------
        np.ndarray
            return new query embeddings: the element-wise mean of the query
            embedding and all candidate embeddings. With no candidates this
            is the query embedding itself.
        """
        candidate_embs = [item.vectors for item in prf_candidates or []]
        # Stack the query and every candidate as separate rows. Unpacking the
        # candidates (instead of passing the list as a single vstack operand)
        # keeps the call valid when the list is empty.
        stacked = np.vstack([emb_qs, *candidate_embs])
        return np.mean(stacked, axis=0)

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Average PRF with Dense Vectors

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings; row ``i`` belongs to ``topic_ids[i]``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain document embeddings.

        Returns
        -------
        np.ndarray
            return new query embeddings as a float32 array, one row per topic
            in the order of ``topic_ids``.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorRocchioPrf(DenseVectorPrf):
    """Rocchio PRF: move the query toward top-ranked and away from bottom-ranked documents."""

    def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int):
        """
        Parameters
        ----------
        alpha : float
            Rocchio parameter, controls the weight assigned to the original query embedding.
        beta : float
            Rocchio parameter, controls the weight assigned to the positive document embeddings.
        gamma : float
            Rocchio parameter, controls the weight assigned to the negative document embeddings.
        topk : int
            Rocchio parameter, set topk documents as positive document feedbacks.
        bottomk : int
            Rocchio parameter, set bottomk documents as negative document feedbacks.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.topk = topk
        self.bottomk = bottomk

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Rocchio PRF with Dense Vectors

        Computes ``alpha * q + beta * mean(first topk) - gamma * mean(last bottomk)``.
        The negative term is applied only when ``bottomk > 0``. A feedback set
        that turns out to be empty (no candidates, or ``topk == 0``) contributes
        nothing, instead of turning the result into NaN via the mean of an
        empty slice.

        Parameters
        ----------
        emb_qs : np.ndarray
            query embedding
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings
            (read from each item's ``vectors`` attribute). ``None`` or an
            empty list means "no feedback".

        Returns
        -------
        np.ndarray
            return new query embeddings
        """
        candidate_embs = [item.vectors for item in prf_candidates or []]
        # The scalar multiplication allocates a fresh array, so the in-place
        # subtraction below never touches the caller's ``emb_qs``.
        new_emb_q = self.alpha * emb_qs

        positive_embs = candidate_embs[:self.topk]
        if positive_embs:
            new_emb_q = new_emb_q + self.beta * np.mean(positive_embs, axis=0)

        if self.bottomk > 0:
            negative_embs = candidate_embs[-self.bottomk:]
            if negative_embs:
                new_emb_q -= self.gamma * np.mean(negative_embs, axis=0)
        return new_emb_q

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Rocchio PRF with Dense Vectors

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings; row ``i`` belongs to ``topic_ids[i]``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain document embeddings.

        Returns
        -------
        np.ndarray
            return new query embeddings as a float32 array, one row per topic
            in the order of ``topic_ids``.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorAncePrf(DenseVectorPrf):
    """ANCE-PRF: re-encode the query text together with the feedback passages."""

    def __init__(self, encoder: AnceQueryEncoder, sparse_searcher: LuceneSearcher):
        """
        Parameters
        ----------
        encoder : AnceQueryEncoder
            The new ANCE query encoder for ANCE-PRF.
        sparse_searcher : LuceneSearcher
            The sparse searcher using lucene index, for retrieving doc contents.
        """
        super().__init__()
        self.encoder = encoder
        self.sparse_searcher = sparse_searcher

    def _build_prf_text(self, query: str, prf_candidates: List[PRFDenseSearchResult]) -> str:
        """Build the encoder input: CLS token, then query and passages each followed by the SEP token."""
        tokenizer = self.encoder.tokenizer
        passage_texts = [query]
        for item in prf_candidates:
            # Each stored document is a JSON string; the passage text is
            # read from its 'contents' field.
            raw_text = json.loads(self.sparse_searcher.doc(item.docid).raw())
            passage_texts.append(raw_text['contents'])
        return f'{tokenizer.cls_token}{tokenizer.sep_token.join(passage_texts)}{tokenizer.sep_token}'

    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform single ANCE-PRF with Dense Vectors

        Parameters
        ----------
        query : str
            query text
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult; each item's ``docid`` is used to
            fetch the passage text from the sparse searcher.

        Returns
        -------
        np.ndarray
            return new query embeddings, reshaped to a single row
            ``(1, len(embedding))``.
        """
        full_text = self._build_prf_text(query, prf_candidates)
        emb_q = self.encoder.prf_encode(full_text)
        emb_q = emb_q.reshape((1, len(emb_q)))
        return emb_q

    def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]]) -> np.ndarray:
        """Perform batch ANCE-PRF with Dense Vectors

        Parameters
        ----------
        topics : List[str]
            List of query texts.
        topic_ids: List[str]
            List of topic ids; ``topic_ids[i]`` identifies ``topics[i]``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult; each
            item's ``docid`` is used to fetch the passage text.

        Returns
        -------
        np.ndarray
            return new query embeddings, as produced by the encoder's
            ``prf_batch_encode`` for the texts in the order of ``topics``.
        """
        prf_passage_texts = [
            self._build_prf_text(query, prf_candidates[topic_ids[index]])
            for index, query in enumerate(topics)
        ]
        return self.encoder.prf_batch_encode(prf_passage_texts)