geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw
history blame
7.54 kB
import numpy as np
from typing import List, Dict
from pyserini.search.faiss import PRFDenseSearchResult, AnceQueryEncoder
from pyserini.search.lucene import LuceneSearcher
import json
class DenseVectorPrf:
    """Base interface for dense-vector pseudo-relevance feedback (PRF) rules.

    Concrete strategies override ``get_prf_q_emb`` (a single query) and
    ``get_batch_prf_q_emb`` (a batch of queries). The base versions accept
    arbitrary keyword arguments, do nothing and return ``None``.
    """

    def __init__(self):
        """No state to set up; kept so subclasses can call it explicitly."""

    def get_prf_q_emb(self, **kwargs):
        """Return a feedback-updated embedding for one query (no-op in the base class)."""
        return None

    def get_batch_prf_q_emb(self, **kwargs):
        """Return feedback-updated embeddings for many queries (no-op in the base class)."""
        return None
class DenseVectorAveragePrf(DenseVectorPrf):
    """Average PRF: the new query embedding is the mean of the query and feedback embeddings."""

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Average PRF with Dense Vectors

        The returned embedding is the element-wise mean (over rows) of the
        query embedding stacked together with every candidate's ``vectors``.

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding (presumably a single 1-D vector; a ``(1, dim)`` row
            stacks the same way -- TODO confirm against callers).
        prf_candidates : List[PRFDenseSearchResult]
            Feedback documents; each item's ``vectors`` attribute is stacked
            with the query. If empty (or None), the result is just the query
            embedding averaged with itself, i.e. the query vector.

        Returns
        -------
        np.ndarray
            The averaged query embedding.
        """
        candidates = prf_candidates or []
        # Stack every vector as its own row: unlike stacking the query against
        # a single list-of-candidates array, this also works when there are no
        # candidates or when candidate vectors are stored as (1, dim) rows.
        stacked = np.vstack([emb_qs, *(item.vectors for item in candidates)])
        return np.mean(stacked, axis=0)

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Average PRF with Dense Vectors for a batch of queries

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids; position ``i`` pairs with ``emb_qs[i]``.
        emb_qs : np.ndarray
            Query embeddings, indexed positionally in the same order as ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Feedback documents for each topic id (``vectors`` holds each
            document embedding).

        Returns
        -------
        np.ndarray
            New query embeddings, one per topic id, cast to ``float32``.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorRocchioPrf(DenseVectorPrf):
    """Rocchio PRF over dense vectors.

    new_q = alpha * q + beta * mean(top-k candidate vectors)
                      - gamma * mean(bottom-k candidate vectors)   (only if bottomk > 0)
    """

    def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int):
        """
        Parameters
        ----------
        alpha : float
            Rocchio parameter, controls the weight assigned to the original query embedding.
        beta : float
            Rocchio parameter, controls the weight assigned to the positive document embeddings.
        gamma : float
            Rocchio parameter, controls the weight assigned to the negative document embeddings.
        topk : int
            Rocchio parameter, set topk documents as positive document feedbacks.
        bottomk : int
            Rocchio parameter, set bottomk documents as negative document feedbacks.
            A value of 0 (or less) disables the negative term.
        """
        DenseVectorPrf.__init__(self)
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.topk = topk
        self.bottomk = bottomk

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Rocchio PRF with Dense Vectors

        The first ``topk`` candidates are positives and the last ``bottomk``
        are negatives (taken in the order given; the two sets may overlap when
        there are few candidates). A term whose document set is empty is
        skipped rather than contributing a NaN mean.

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding.
        prf_candidates : List[PRFDenseSearchResult]
            Feedback documents; each item's ``vectors`` attribute is used.

        Returns
        -------
        np.ndarray
            The Rocchio-updated query embedding.
        """
        all_candidate_embs = [item.vectors for item in (prf_candidates or [])]
        new_emb_q = self.alpha * emb_qs

        positive_embs = all_candidate_embs[:self.topk]
        # np.mean of an empty list is NaN (with a RuntimeWarning); with no
        # positives (topk == 0 or no candidates) keep just the weighted query.
        if positive_embs:
            new_emb_q = new_emb_q + self.beta * np.mean(positive_embs, axis=0)

        if self.bottomk > 0:
            negative_embs = all_candidate_embs[-self.bottomk:]
            if negative_embs:
                new_emb_q -= self.gamma * np.mean(negative_embs, axis=0)
        return new_emb_q

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Rocchio PRF with Dense Vectors for a batch of queries

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids; position ``i`` pairs with ``emb_qs[i]``.
        emb_qs : np.ndarray
            Query embeddings, indexed positionally in the same order as ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Feedback documents for each topic id.

        Returns
        -------
        np.ndarray
            New query embeddings, one per topic id, cast to ``float32``.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')
class DenseVectorAncePrf(DenseVectorPrf):
    """ANCE-PRF: re-encode the query concatenated with its feedback passages' text."""

    def __init__(self, encoder: AnceQueryEncoder, sparse_searcher: LuceneSearcher):
        """
        Parameters
        ----------
        encoder : AnceQueryEncoder
            The new ANCE query encoder for ANCE-PRF.
        sparse_searcher : LuceneSearcher
            The sparse searcher using lucene index, for retrieving doc contents.
        """
        DenseVectorPrf.__init__(self)
        self.encoder = encoder
        self.sparse_searcher = sparse_searcher

    def _build_prf_text(self, query: str, prf_candidates: List[PRFDenseSearchResult]) -> str:
        """Join the query and each candidate's stored contents into one tokenizer-marked string.

        Each candidate's raw stored document is parsed as JSON and its
        ``'contents'`` field is used; the pieces are joined with the
        tokenizer's SEP token, prefixed by CLS and terminated by SEP.
        """
        passage_texts = [query]
        for item in prf_candidates:
            raw_text = json.loads(self.sparse_searcher.doc(item.docid).raw())
            passage_texts.append(raw_text['contents'])
        tokenizer = self.encoder.tokenizer
        return f'{tokenizer.cls_token}{tokenizer.sep_token.join(passage_texts)}{tokenizer.sep_token}'

    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform single ANCE-PRF with Dense Vectors

        Parameters
        ----------
        query : str
            query text
        prf_candidates : List[PRFDenseSearchResult]
            Feedback documents; only each item's ``docid`` is used, to fetch
            its text from the sparse index.

        Returns
        -------
        np.ndarray
            The new query embedding reshaped to a single row ``(1, len(emb))``.
        """
        full_text = self._build_prf_text(query, prf_candidates)
        emb_q = self.encoder.prf_encode(full_text)
        return emb_q.reshape((1, len(emb_q)))

    def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]]) -> np.ndarray:
        """Perform batch ANCE-PRF with Dense Vectors

        Parameters
        ----------
        topics : List[str]
            List of query texts.
        topic_ids: List[str]
            List of topic ids; ``topic_ids[i]`` identifies ``topics[i]``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Feedback documents per topic id; only each item's ``docid`` is used.

        Returns
        -------
        np.ndarray
            New query embeddings as returned by ``encoder.prf_batch_encode``.
        """
        prf_passage_texts = [
            self._build_prf_text(query, prf_candidates[topic_ids[index]])
            for index, query in enumerate(topics)
        ]
        return self.encoder.prf_batch_encode(prf_passage_texts)