"""Pseudo-relevance feedback (PRF) query refinement over dense vectors."""

import json
from typing import Dict, List

import numpy as np

from pyserini.search.faiss import PRFDenseSearchResult, AnceQueryEncoder
from pyserini.search.lucene import LuceneSearcher


class DenseVectorPrf:
    """Base class for dense-vector PRF strategies.

    Both methods are placeholders that return ``None``; subclasses override
    them with their own signatures.
    """

    def __init__(self):
        pass

    def get_prf_q_emb(self, **kwargs):
        pass

    def get_batch_prf_q_emb(self, **kwargs):
        pass


class DenseVectorAveragePrf(DenseVectorPrf):
    """PRF that averages the query embedding with the feedback embeddings."""

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Average PRF with Dense Vectors

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings.

        Returns
        -------
        np.ndarray
            The new query embedding: the row-wise mean of the query embedding
            and every candidate's ``vectors``. With no candidates this is the
            query embedding itself.
        """
        candidate_embs = [item.vectors for item in prf_candidates]
        # Stack the candidates individually rather than as one nested list so
        # that an empty candidate list still stacks (it yields just the query).
        return np.mean(np.vstack([emb_qs, *candidate_embs]), axis=0)

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Average PRF with Dense Vectors

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings, indexed in the same order as ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain the document embeddings.

        Returns
        -------
        np.ndarray
            The new query embeddings as a ``float32`` array, one entry per
            topic id.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')


class DenseVectorRocchioPrf(DenseVectorPrf):
    """PRF that applies the Rocchio update to the query embedding."""

    def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk: int):
        """
        Parameters
        ----------
        alpha : float
            Rocchio parameter, controls the weight assigned to the original query embedding.
        beta : float
            Rocchio parameter, controls the weight assigned to the positive document embeddings.
        gamma : float
            Rocchio parameter, controls the weight assigned to the negative document embeddings.
        topk : int
            Rocchio parameter, set topk documents as positive document feedbacks.
        bottomk : int
            Rocchio parameter, set bottomk documents as negative document feedbacks.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.topk = topk
        self.bottomk = bottomk

    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform Rocchio PRF with Dense Vectors

        Computes ``alpha * q + beta * mean(first topk) - gamma * mean(last bottomk)``.
        The negative term is applied only when ``bottomk > 0``. A feedback
        term whose document set is empty is skipped, because the mean of an
        empty set is NaN and would corrupt the whole embedding.

        Parameters
        ----------
        emb_qs : np.ndarray
            query embedding
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult, contains document embeddings.

        Returns
        -------
        np.ndarray
            return new query embeddings
        """
        candidate_embs = [item.vectors for item in prf_candidates]
        new_emb_q = self.alpha * emb_qs

        positive_embs = candidate_embs[:self.topk]
        if positive_embs:
            new_emb_q = new_emb_q + self.beta * np.mean(positive_embs, axis=0)

        if self.bottomk > 0:
            negative_embs = candidate_embs[-self.bottomk:]
            if negative_embs:
                new_emb_q = new_emb_q - self.gamma * np.mean(negative_embs, axis=0)

        return new_emb_q

    def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
        """Perform Rocchio PRF with Dense Vectors

        Parameters
        ----------
        topic_ids : List[str]
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings, indexed in the same order as ``topic_ids``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult, which
            contain the document embeddings.

        Returns
        -------
        np.ndarray
            The new query embeddings as a ``float32`` array, one entry per
            topic id.
        """
        new_emb_qs = [
            self.get_prf_q_emb(emb_qs[index], prf_candidates[topic_id])
            for index, topic_id in enumerate(topic_ids)
        ]
        return np.array(new_emb_qs).astype('float32')


class DenseVectorAncePrf(DenseVectorPrf):
    """PRF that re-encodes the query together with the feedback passages."""

    def __init__(self, encoder: AnceQueryEncoder, sparse_searcher: LuceneSearcher):
        """
        Parameters
        ----------
        encoder : AnceQueryEncoder
            The new ANCE query encoder for ANCE-PRF.
        sparse_searcher : LuceneSearcher
            The sparse searcher using lucene index, for retrieving doc contents.
        """
        super().__init__()
        self.encoder = encoder
        self.sparse_searcher = sparse_searcher

    def _build_prf_text(self, query: str, prf_candidates: List[PRFDenseSearchResult]) -> str:
        """Build the encoder input: CLS + query and passages joined by SEP + SEP.

        Each candidate's passage text is the ``contents`` field of the JSON
        returned by ``sparse_searcher.doc(docid).raw()``.
        """
        tokenizer = self.encoder.tokenizer
        passage_texts = [query]
        for item in prf_candidates:
            raw_doc = json.loads(self.sparse_searcher.doc(item.docid).raw())
            passage_texts.append(raw_doc['contents'])
        return f'{tokenizer.cls_token}{tokenizer.sep_token.join(passage_texts)}{tokenizer.sep_token}'

    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchResult] = None):
        """Perform single ANCE-PRF with Dense Vectors

        Parameters
        ----------
        query : str
            query text
        prf_candidates : List[PRFDenseSearchResult]
            List of PRFDenseSearchResult; their ``docid`` is used to look up
            the passage text.

        Returns
        -------
        np.ndarray
            The new query embedding, reshaped to a single row
            ``(1, len(embedding))``.
        """
        emb_q = self.encoder.prf_encode(self._build_prf_text(query, prf_candidates))
        return emb_q.reshape((1, len(emb_q)))

    def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
                            prf_candidates: Dict[str, List[PRFDenseSearchResult]]) -> np.ndarray:
        """Perform batch ANCE-PRF with Dense Vectors

        Parameters
        ----------
        topics : List[str]
            List of query texts.
        topic_ids: List[str]
            List of topic ids, in the same order as ``topics``.
        prf_candidates : Dict[str, List[PRFDenseSearchResult]]
            Maps each topic id to its list of PRFDenseSearchResult; their
            ``docid`` is used to look up the passage text.

        Returns
        -------
        np.ndarray
            The new query embeddings, as returned by the encoder's
            ``prf_batch_encode``.
        """
        prf_passage_texts = [
            self._build_prf_text(query, prf_candidates[topic_ids[index]])
            for index, query in enumerate(topics)
        ]
        return self.encoder.prf_batch_encode(prf_passage_texts)