geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw
history blame contribute delete
No virus
7.54 kB
import numpy as np
from typing import List, Dict
from pyserini.search.faiss import PRFDenseSearchResult, AnceQueryEncoder
from pyserini.search.lucene import LuceneSearcher
import json
class DenseVectorPrf:
    """Common interface for dense-vector pseudo-relevance-feedback (PRF) strategies.

    Both hooks defined here are no-ops that return ``None``; concrete
    strategies are expected to override them.
    """

    def __init__(self):
        """Create the strategy; the base class keeps no state."""

    def get_prf_q_emb(self, **kwargs):
        """Placeholder for refining one query embedding.

        Accepts arbitrary keyword arguments, ignores them and returns ``None``.
        """
        return None

    def get_batch_prf_q_emb(self, **kwargs):
        """Placeholder for refining a batch of query embeddings.

        Accepts arbitrary keyword arguments, ignores them and returns ``None``.
        """
        return None
class DenseVectorAveragePrf(DenseVectorPrf):
    """Average PRF: the new query is the mean of the query and its feedback document embeddings."""

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Average PRF with Dense Vectors

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding
        prf_candidates : List[PRFDenseSearchResult]
            Feedback results; the ``vectors`` attribute of each item is used
            as a document embedding.

        Returns
        -------
        np.ndarray
            Mean (over axis 0) of the query embedding stacked with every
            candidate embedding. With no candidates this is the query
            embedding itself.
        """
        candidate_embs = [item.vectors for item in prf_candidates]
        # Stack row by row instead of stacking the query against the whole
        # candidate list: an empty list would otherwise become a (1, 0) array
        # and make np.vstack raise on the shape mismatch.
        stacked = np.vstack([emb_qs, *candidate_embs])
        return np.mean(stacked, axis=0)

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Average PRF with Dense Vectors for a batch of queries

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings; ``emb_qs[i]`` is the embedding for ``topic_ids[i]``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain the document embeddings.

        Returns
        -------
        np.ndarray
            New query embeddings as a ``float32`` array, one entry per topic
            id, in the order of ``topic_ids``.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorRocchioPrf(DenseVectorPrf):
    """Rocchio PRF: shift the query towards top-ranked and away from bottom-ranked feedback documents."""

    def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int):
        """
        Parameters
        ----------
        alpha : float
            Rocchio parameter, controls the weight assigned to the original query embedding.
        beta : float
            Rocchio parameter, controls the weight assigned to the positive document embeddings.
        gamma : float
            Rocchio parameter, controls the weight assigned to the negative document embeddings.
        topk : int
            Rocchio parameter, set topk documents as positive document feedbacks.
        bottomk : int
            Rocchio parameter, set bottomk documents as negative document feedbacks.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.topk = topk
        self.bottomk = bottomk

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Rocchio PRF with Dense Vectors

        Computes ``alpha * query + beta * mean(first topk) - gamma * mean(last bottomk)``.
        The negative term is applied only when ``bottomk > 0``. A term whose
        document set is empty is skipped, so an empty candidate list (or
        ``topk == 0``) does not turn the result into NaN.

        Parameters
        ----------
        emb_qs : np.ndarray
            query embedding
        prf_candidates : List[PRFDenseSearchResult]
            Feedback results; the ``vectors`` attribute of each item is used
            as a document embedding.

        Returns
        -------
        np.ndarray
            return new query embeddings
        """
        candidate_embs = [item.vectors for item in prf_candidates]
        new_emb_q = self.alpha * emb_qs

        positive_embs = candidate_embs[:self.topk]
        # np.mean over an empty list is NaN (with a RuntimeWarning), which
        # would poison every component of the query; skip the term instead.
        if positive_embs:
            new_emb_q = new_emb_q + self.beta * np.mean(positive_embs, axis=0)

        if self.bottomk > 0:
            negative_embs = candidate_embs[-self.bottomk:]
            if negative_embs:
                # In-place subtraction, as before, keeps the dtype of the
                # accumulated query embedding.
                new_emb_q -= self.gamma * np.mean(negative_embs, axis=0)
        return new_emb_q

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Rocchio PRF with Dense Vectors for a batch of queries

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings; ``emb_qs[i]`` is the embedding for ``topic_ids[i]``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain the document embeddings.

        Returns
        -------
        np.ndarray
            New query embeddings as a ``float32`` array, one entry per topic
            id, in the order of ``topic_ids``.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorAncePrf(DenseVectorPrf):
    """ANCE-PRF: re-encode the query text together with the text of its feedback documents."""

    def __init__(self, encoder: AnceQueryEncoder, sparse_searcher: LuceneSearcher):
        """
        Parameters
        ----------
        encoder : AnceQueryEncoder
            The new ANCE query encoder for ANCE-PRF.
        sparse_searcher : LuceneSearcher
            The sparse searcher using lucene index, for retrieving doc contents.
        """
        super().__init__()
        self.sparse_searcher = sparse_searcher
        self.encoder = encoder

    def _compose_prf_input(self, query, feedback_docs):
        """Build the encoder input string for one query and its feedback documents.

        The query comes first, followed by the ``'contents'`` field of each
        document's raw JSON (looked up by ``docid`` through the sparse
        searcher). Segments are joined with the tokenizer's separator token,
        prefixed with its CLS token and terminated by one more separator.
        """
        segments = [query]
        segments.extend(
            json.loads(self.sparse_searcher.doc(hit.docid).raw())['contents']
            for hit in feedback_docs
        )
        tokenizer = self.encoder.tokenizer
        return f'{tokenizer.cls_token}{tokenizer.sep_token.join(segments)}{tokenizer.sep_token}'

    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform single ANCE-PRF with Dense Vectors

        Parameters
        ----------
        query : str
            query text
        prf_candidates : List[PRFDenseSearchResult]
            Feedback results; only the ``docid`` of each item is used, to
            fetch the document text.

        Returns
        -------
        np.ndarray
            The output of ``encoder.prf_encode`` reshaped to a single row,
            i.e. shape ``(1, len(embedding))``.
        """
        prf_input = self._compose_prf_input(query, prf_candidates)
        encoded = self.encoder.prf_encode(prf_input)
        return encoded.reshape((1, len(encoded)))

    def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]]) -> np.ndarray:
        """Perform batch ANCE-PRF with Dense Vectors

        Parameters
        ----------
        topics : List[str]
            List of query texts.
        topic_ids: List[str]
            List of topic ids; ``topic_ids[i]`` identifies ``topics[i]``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its feedback results.

        Returns
        -------
        np.ndarray
            Whatever ``encoder.prf_batch_encode`` returns for the composed
            inputs, one input per query in the order of ``topics``.
        """
        prf_inputs = [
            self._compose_prf_input(query_text, prf_candidates[topic_ids[position]])
            for position, query_text in enumerate(topics)
        ]
        return self.encoder.prf_batch_encode(prf_inputs)