|
import torch |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
def is_str_list(obj):
    """Return True when *obj* is a ``list`` and every element is a ``str``.

    An empty list qualifies (there is no non-string element). Tuples and other
    sequences do not.
    """
    if not isinstance(obj, list):
        return False
    for element in obj:
        if not isinstance(element, str):
            return False
    return True
|
|
|
def is_np_list(obj):
    """Return True when *obj* is a ``list`` whose elements are all ``np.ndarray``.

    An empty list qualifies; a bare ``np.ndarray`` does not (it is not a list).
    """
    if not isinstance(obj, list):
        return False
    for element in obj:
        if not isinstance(element, np.ndarray):
            return False
    return True
|
|
|
def is_np_array(obj):
    """Return True when *obj* is an ``np.ndarray`` (or a subclass of it)."""
    accepted_types = (np.ndarray,)
    return isinstance(obj, accepted_types)
|
|
|
class Sent_Retriever:
    """Shared base for retrievers built on a ``SentenceTransformer`` model.

    Subclasses must set ``self.model`` to a SentenceTransformer and define
    ``embed_queries`` / ``embed_quotes`` (called by :meth:`score` when it is
    given raw text). This base class supplies batching, scoring and token
    counting.
    """

    def __init__(self, bs=256, use_gpu=True):
        """Store the batch size and pick the compute device.

        Args:
            bs: Number of texts passed to each ``model.encode`` call.
            use_gpu: Use CUDA when it is available; otherwise use the CPU.
        """
        self.bs = bs
        self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")

    def embed_passages(self, passages, prefix=""):
        """Encode texts in batches of ``self.bs`` into normalised embeddings.

        Args:
            passages: A list of strings; a single string is treated as a
                one-element list.
            prefix: Text prepended to every passage before encoding (e.g. an
                instruction or role marker). Empty means no prefix.

        Returns:
            A list with one embedding per passage, in input order (the rows
            returned by ``model.encode`` with ``normalize_embeddings=True``).
        """
        if isinstance(passages, str):
            # A bare string would otherwise be iterated/sliced character by character.
            passages = [passages]
        if prefix:
            passages = [prefix + item for item in passages]

        embeddings = []
        with torch.no_grad():
            for i in tqdm(range(0, len(passages), self.bs)):
                batch_passage = passages[i:(i + self.bs)]
                emb = self.model.encode(batch_passage, normalize_embeddings=True)
                embeddings.extend(emb)
        return embeddings

    def _to_embedding_matrix(self, inputs, embed_method):
        """Return ``inputs`` as an array of embeddings, one row per item.

        Args:
            inputs: A string or list of strings (embedded with ``embed_method``),
                a list of np.ndarray embeddings (stacked), or an np.ndarray
                (returned unchanged).
            embed_method: Name of the method that embeds text. It is looked up
                only when text is given, so array-only calls need no such method.

        Raises:
            TypeError: if ``inputs`` is none of the accepted forms.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        if is_str_list(inputs):
            return np.asarray(getattr(self, embed_method)(inputs))
        if is_np_list(inputs):
            return np.asarray(inputs)
        if is_np_array(inputs):
            return inputs
        raise TypeError(
            "expected str, list of str, list of np.ndarray or np.ndarray, got {}".format(
                type(inputs).__name__))

    def score(self, queries, quotes):
        """Return the query-by-quote similarity matrix as nested lists.

        Args:
            queries: A query string, a list of query strings, or precomputed
                query embeddings (a list of np.ndarray or an np.ndarray).
            quotes: The same accepted forms, for the candidate quotes.

        Returns:
            ``scores[i][j]`` is the dot product of query ``i``'s and quote
            ``j``'s embeddings.

        Raises:
            TypeError: if either argument is of an unsupported type.
        """
        query_emb = self._to_embedding_matrix(queries, "embed_queries")
        quote_emb = self._to_embedding_matrix(quotes, "embed_quotes")
        return (query_emb @ quote_emb.T).tolist()

    def get_tok_len(self, text_input):
        """Return how many tokens ``text_input`` encodes to, without truncation.

        Uses the tokenizer of the SentenceTransformer's first module. The count
        is the length of ``input_ids``, so it includes any special tokens the
        tokenizer adds by default.
        """
        # NOTE(review): max_length=False (not None) is kept as in the original;
        # presumably it avoids the tokenizer's "sequence too long" warning for
        # over-length inputs — confirm before changing.
        encoded = self.model._first_module().tokenizer(
            text=[text_input],
            truncation=False, max_length=False, return_tensors="pt"
        )
        return encoded["input_ids"].size()[-1]
|
|
|
|
|
class BGE(Sent_Retriever):
    """Sentence retriever for BGE models (default checkpoint: ``bge-large-en-v1.5``).

    Queries are prefixed with a retrieval instruction; quotes are encoded as-is.
    """

    # Query-side instruction as published on the BGE model card, including its
    # trailing space so the instruction stays separated from the query text.
    QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "

    def __init__(self, bs=256, use_gpu=True, model_path="checkpoint/bge-large-en-v1.5"):
        """Load the SentenceTransformer at ``model_path`` and move it to the device.

        Args:
            bs: Encoding batch size.
            use_gpu: Use CUDA when it is available.
            model_path: Path handed to ``SentenceTransformer``.
        """
        from sentence_transformers import SentenceTransformer

        super().__init__(bs=bs, use_gpu=use_gpu)
        self.model_path = model_path
        self.model = SentenceTransformer(self.model_path)
        print("[text_wrapper.py - init] Setting up BGE...")
        print("[text_wrapper.py - init] BGE is loaded from '{}'...".format(self.model_path))
        self.model.eval()
        self.model = self.model.to(self.device)

    def embed_queries(self, queries):
        """Embed a query string or list of queries, prefixed with the BGE instruction."""
        if isinstance(queries, str):
            queries = [queries]
        return self.embed_passages(queries, self.QUERY_INSTRUCTION)

    def embed_quotes(self, quotes):
        """Embed a quote string or list of quotes without any prefix."""
        if isinstance(quotes, str):
            quotes = [quotes]
        return self.embed_passages(quotes)
|
|
|
|
|
class E5(Sent_Retriever):
    """Sentence retriever for E5 models (default checkpoint: ``e5-large-v2``).

    Texts are encoded with role prefixes: queries as "query: <text>" and
    quotes as "passage: <text>".
    """

    # Role prefixes per the E5 model card. Both end with a space, so the marker
    # and the text are separated in the same way for queries and passages.
    QUERY_PREFIX = "query: "
    PASSAGE_PREFIX = "passage: "

    def __init__(self, bs=256, use_gpu=True, model_path="checkpoint/e5-large-v2"):
        """Load the SentenceTransformer at ``model_path`` and move it to the device.

        Args:
            bs: Encoding batch size.
            use_gpu: Use CUDA when it is available.
            model_path: Path handed to ``SentenceTransformer``.
        """
        from sentence_transformers import SentenceTransformer

        super().__init__(bs=bs, use_gpu=use_gpu)
        self.model_path = model_path
        self.model = SentenceTransformer(self.model_path)
        print("[text_wrapper.py - init] Setting up E5...")
        print("[text_wrapper.py - init] E5 is loaded from '{}'...".format(self.model_path))
        self.model.eval()
        self.model = self.model.to(self.device)

    def embed_queries(self, queries):
        """Embed a query string or list of queries with the "query: " prefix."""
        if isinstance(queries, str):
            queries = [queries]
        return self.embed_passages(queries, self.QUERY_PREFIX)

    def embed_quotes(self, quotes):
        """Embed a quote string or list of quotes with the "passage: " prefix."""
        if isinstance(quotes, str):
            quotes = [quotes]
        return self.embed_passages(quotes, self.PASSAGE_PREFIX)
|
|
|
|
|
class GTE(Sent_Retriever):
    """Sentence retriever for GTE models (default checkpoint: ``gte-large``).

    No prefix is added: queries and quotes go through the same encoding path.
    """

    def __init__(self, bs=256, use_gpu=True, model_path="checkpoint/gte-large"):
        """Load the SentenceTransformer at ``model_path`` and move it to the device.

        Args:
            bs: Encoding batch size.
            use_gpu: Use CUDA when it is available.
            model_path: Path handed to ``SentenceTransformer``.
        """
        from sentence_transformers import SentenceTransformer

        super().__init__(bs=bs, use_gpu=use_gpu)
        self.model_path = model_path
        encoder = SentenceTransformer(model_path)
        print("[text_wrapper.py - init] Setting up GTE...")
        print("[text_wrapper.py - init] GTE is loaded from '{}'...".format(model_path))
        encoder.eval()
        self.model = encoder.to(self.device)

    def embed_queries(self, queries):
        """Embed one query string or a list of them (no prefix)."""
        batch = [queries] if isinstance(queries, str) else queries
        return self.embed_passages(batch)

    def embed_quotes(self, quotes):
        """Embed one quote string or a list of them (no prefix)."""
        batch = [quotes] if isinstance(quotes, str) else quotes
        return self.embed_passages(batch)
|
|
|
|
|
class Contriever:
    """Dense retriever using a Contriever checkpoint with mean pooling.

    Queries and quotes share one encoder. Each text's embedding is the mean of
    its token embeddings over the attention mask, and scores are dot products.
    """

    def __init__(self, bs=256, use_gpu=True, model_path='checkpoint/contriever-msmarco'):
        """Load tokenizer and encoder from ``model_path`` and move the encoder to the device.

        Args:
            bs: Encoding batch size.
            use_gpu: Use CUDA when it is available.
            model_path: Path handed to ``AutoTokenizer`` / ``AutoModel.from_pretrained``.
        """
        from transformers import AutoTokenizer, AutoModel

        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModel.from_pretrained(self.model_path)
        self.bs = bs
        self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")
        print("[text_wrapper.py - init] Setting up Contriever...")
        print("[text_wrapper.py - init] Contriever is loaded from '{}'...".format(self.model_path))
        self.model.eval()
        self.model = self.model.to(self.device)

    def mean_pooling(self, token_embeddings, mask):
        """Average token embeddings over the positions where ``mask`` is non-zero.

        Args:
            token_embeddings: Per-token embeddings, presumably (batch, seq_len, hidden).
            mask: Attention mask, presumably (batch, seq_len); zeros mark padding.

        Returns:
            One pooled vector per batch row (sum over dim 1 divided by the mask count).
        """
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

    def embed_queries(self, query):
        """Embed a query string or list of queries (same encoder as quotes)."""
        return self.embed_passages(query)

    def embed_quotes(self, quotes):
        """Embed a quote string or list of quotes."""
        return self.embed_passages(quotes)

    def embed_passages(self, quotes):
        """Encode texts in batches of ``self.bs`` and mean-pool them.

        Args:
            quotes: A string or a list of strings (truncated to 512 tokens).

        Returns:
            A list with one np.ndarray embedding per input text, in input order.
        """
        if isinstance(quotes, str):
            quotes = [quotes]
        quote_embeddings = []
        with torch.no_grad():
            for i in tqdm(range(0, len(quotes), self.bs)):
                batch_quotes = quotes[i:(i + self.bs)]
                encoded_quotes = self.tokenizer.batch_encode_plus(
                    batch_quotes, return_tensors="pt",
                    max_length=512, padding=True, truncation=True)
                encoded_data = {k: v.to(self.device) for k, v in encoded_quotes.items()}
                batched_outputs = self.model(**encoded_data)
                # batched_outputs[0] is the model's first output (per-token states).
                batched_quote_embs = self.mean_pooling(batched_outputs[0], encoded_data['attention_mask'])
                quote_embeddings.extend([q.cpu().detach().numpy() for q in batched_quote_embs])
        return quote_embeddings

    def _to_embedding_matrix(self, inputs, embed_method):
        """Return ``inputs`` as an array of embeddings, one row per item.

        Args:
            inputs: A string or list of strings (embedded with ``embed_method``),
                a list of np.ndarray embeddings (stacked), or an np.ndarray
                (returned unchanged).
            embed_method: Name of the method that embeds text.

        Raises:
            TypeError: if ``inputs`` is none of the accepted forms.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        if is_str_list(inputs):
            return np.asarray(getattr(self, embed_method)(inputs))
        if is_np_list(inputs):
            return np.asarray(inputs)
        if is_np_array(inputs):
            return inputs
        raise TypeError(
            "expected str, list of str, list of np.ndarray or np.ndarray, got {}".format(
                type(inputs).__name__))

    def score(self, queries, quotes):
        """Return the query-by-quote dot-product matrix as nested lists.

        Args:
            queries: A query string, a list of query strings, or precomputed
                query embeddings (a list of np.ndarray or an np.ndarray).
            quotes: The same accepted forms, for the candidate quotes.

        Returns:
            ``scores[i][j]`` is the dot product of query ``i`` and quote ``j``.

        Raises:
            TypeError: if either argument is of an unsupported type.
        """
        query_emb = self._to_embedding_matrix(queries, "embed_queries")
        quote_emb = self._to_embedding_matrix(quotes, "embed_quotes")
        return (query_emb @ quote_emb.T).tolist()
|
|
|
|
|
class DPR:
    """Dense Passage Retrieval with separate question and context encoders.

    Loads ``dpr-question_encoder-multiset-base`` and
    ``dpr-ctx_encoder-multiset-base`` from sub-directories of ``model_path``.
    Embeddings are the encoders' ``pooler_output`` and scores are dot products.
    """

    def __init__(self, bs=256, use_gpu=True, model_path="checkpoint"):
        """Load both tokenizer/encoder pairs and move the encoders to the device.

        Args:
            bs: Encoding batch size.
            use_gpu: Use CUDA when it is available.
            model_path: Directory containing the two DPR checkpoint sub-directories.
        """
        from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer

        self.model_path = model_path
        query_dir = self.model_path + "/dpr-question_encoder-multiset-base"
        ctx_dir = self.model_path + "/dpr-ctx_encoder-multiset-base"
        self.query_tok = DPRQuestionEncoderTokenizer.from_pretrained(query_dir)
        self.query_enc = DPRQuestionEncoder.from_pretrained(query_dir)
        self.ctx_tok = DPRContextEncoderTokenizer.from_pretrained(ctx_dir)
        self.ctx_enc = DPRContextEncoder.from_pretrained(ctx_dir)
        self.bs = bs
        print("[text_wrapper.py - init] Setting up DPR...")
        print("[text_wrapper.py - init] DPR is loaded from '{}'...".format(self.model_path))
        self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")
        self.query_enc.eval()
        self.query_enc = self.query_enc.to(self.device)
        self.ctx_enc.eval()
        self.ctx_enc = self.ctx_enc.to(self.device)

    def _encode(self, texts, tokenizer, encoder):
        """Encode texts in batches of ``self.bs`` with one tokenizer/encoder pair.

        Args:
            texts: A string or a list of strings (truncated to 512 tokens).
            tokenizer: The DPR tokenizer matching ``encoder``.
            encoder: A DPR encoder already placed on ``self.device``.

        Returns:
            A list with one np.ndarray (a ``pooler_output`` row) per text.
        """
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        with torch.no_grad():
            for i in tqdm(range(0, len(texts), self.bs)):
                batch = texts[i:(i + self.bs)]
                encoded = tokenizer.batch_encode_plus(
                    batch, truncation=True, padding=True,
                    return_tensors='pt', max_length=512)
                # Put inputs on the encoder's device; a hard-coded .cuda() would
                # fail whenever the encoders were placed on the CPU.
                encoded_data = {k: v.to(self.device) for k, v in encoded.items()}
                pooled = encoder(**encoded_data).pooler_output
                embeddings.extend([row.cpu().detach().numpy() for row in pooled])
        return embeddings

    def embed_queries(self, queries):
        """Embed a query string or list of queries with the question encoder."""
        return self._encode(queries, self.query_tok, self.query_enc)

    def embed_quotes(self, quotes):
        """Embed a quote string or list of quotes with the context encoder."""
        return self._encode(quotes, self.ctx_tok, self.ctx_enc)

    def _to_embedding_matrix(self, inputs, embed_method):
        """Return ``inputs`` as an array of embeddings, one row per item.

        Args:
            inputs: A string or list of strings (embedded with ``embed_method``),
                a list of np.ndarray embeddings (stacked), or an np.ndarray
                (returned unchanged).
            embed_method: Name of the method that embeds text.

        Raises:
            TypeError: if ``inputs`` is none of the accepted forms.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        if is_str_list(inputs):
            return np.asarray(getattr(self, embed_method)(inputs))
        if is_np_list(inputs):
            return np.asarray(inputs)
        if is_np_array(inputs):
            return inputs
        raise TypeError(
            "expected str, list of str, list of np.ndarray or np.ndarray, got {}".format(
                type(inputs).__name__))

    def score(self, queries, quotes):
        """Return the query-by-quote dot-product matrix as nested lists.

        Args:
            queries: A query string, a list of query strings, or precomputed
                query embeddings (a list of np.ndarray or an np.ndarray).
            quotes: The same accepted forms, for the candidate quotes.

        Returns:
            ``scores[i][j]`` is the dot product of query ``i`` and quote ``j``.

        Raises:
            TypeError: if either argument is of an unsupported type.
        """
        query_emb = self._to_embedding_matrix(queries, "embed_queries")
        quote_emb = self._to_embedding_matrix(quotes, "embed_quotes")
        return (query_emb @ quote_emb.T).tolist()
|
|
|
|
|
class ColBERTReranker:
    """Late-interaction (MaxSim) reranker built on a ColBERT checkpoint.

    Queries and quotes are encoded into per-token embedding matrices. A quote's
    score for a query is the sum, over query tokens, of the best dot product
    with any unpadded quote token.
    """

    def __init__(self, bs=256, use_gpu=True, model_path="checkpoint/colbertv2.0"):
        """Load the tokenizer and ColBERT model and resolve the marker token ids.

        Args:
            bs: Encoding batch size (also passed to ``ColBERTConfig`` as ``bsize``).
            use_gpu: Use CUDA when it is available.
            model_path: Checkpoint path for both the tokenizer and ColBERT.
        """
        from colbert.modeling.colbert import ColBERT
        from colbert.infra import ColBERTConfig
        from transformers import AutoTokenizer

        self.model_path = model_path
        self.bs = bs
        config = ColBERTConfig(bsize=bs, root='./', query_token_id='[Q]', doc_token_id='[D]')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = ColBERT(name=self.model_path, colbert_config=config)
        # NOTE(review): the marker ids come from converting the literal strings
        # '[Q]' / '[D]' with this checkpoint's tokenizer. If its vocabulary has no
        # such tokens, both resolve to the unknown-token id. ColBERT's own config
        # defaults appear to use '[unused0]' / '[unused1]' for these fields —
        # confirm against the checkpoint before relying on the markers.
        self.doc_token_id = self.tokenizer.convert_tokens_to_ids(config.doc_token_id)
        self.query_token_id = self.tokenizer.convert_tokens_to_ids(config.query_token_id)
        self.add_special_tokens = True
        self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")
        print("[text_wrapper.py - init] Setting up ColBERT Reranker...")
        print("[text_wrapper.py - init] ColBERT is loaded from '{}'...".format(self.model_path))
        self.model.eval()
        self.model = self.model.to(self.device)

    def _encode_token_level(self, texts, marker_id, encode_fn, **tokenizer_kwargs):
        """Encode texts into unpadded per-token embeddings carrying a ColBERT marker.

        Each text is prefixed with '. ' so that position 1 of the tokenized input
        (after the tokenizer's leading special token, assumed BERT-style [CLS])
        is a throwaway '.'. That position is overwritten with ``marker_id``, so
        no real token of the text is lost.

        Args:
            texts: A string or a list of strings (truncated to 512 tokens).
            marker_id: Token id written at position 1 (query or doc marker).
            encode_fn: ``self.model.query`` or ``self.model.doc``; called with
                (input_ids, attention_mask).
            **tokenizer_kwargs: Extra keyword arguments for ``batch_encode_plus``.

        Returns:
            A list with one np.ndarray per text: the rows of ``encode_fn``'s
            output up to that text's attention-mask length (assumes padding is
            on the right).
        """
        if isinstance(texts, str):
            texts = [texts]
        prefixed = ['. ' + text for text in texts]
        embeddings = []
        with torch.no_grad():
            for i in tqdm(range(0, len(prefixed), self.bs)):
                batch = prefixed[i:(i + self.bs)]
                encoded = self.tokenizer.batch_encode_plus(
                    batch, max_length=512, padding=True, truncation=True,
                    return_tensors='pt', **tokenizer_kwargs)
                encoded_data = {k: v.to(self.device) for k, v in encoded.items()}
                encoded_data['input_ids'][:, 1] = marker_id
                batch_embs = encode_fn(encoded_data['input_ids'], encoded_data['attention_mask'])
                for emb, mask in zip(batch_embs, encoded_data['attention_mask']):
                    length = mask.sum().item()
                    embeddings.append(emb[:length].cpu().numpy())
        return embeddings

    def embed_queries(self, queries):
        """Encode queries into per-token embeddings with the query marker at position 1.

        Args:
            queries: A query string or a list of query strings.

        Returns:
            A list with one np.ndarray of token embeddings per query.
        """
        return self._encode_token_level(
            queries, self.query_token_id, self.model.query,
            add_special_tokens=self.add_special_tokens)

    @staticmethod
    def pad_tok_len(quote_embeddings, pad_value=0):
        """Stack variable-length token-embedding matrices into one padded array.

        Args:
            quote_embeddings: Non-empty sequence of (num_tokens_i, dim) arrays
                sharing ``dim``.
            pad_value: Fill value for padded positions.

        Returns:
            ``(padded, masks)``. ``padded`` has shape (N, max_len, dim) and the
            first array's dtype. ``masks`` is an int64 (N, max_len) array with 1
            on real tokens and 0 on padding.

        Raises:
            ValueError: if ``quote_embeddings`` is empty.
        """
        if len(quote_embeddings) == 0:
            raise ValueError("pad_tok_len() needs at least one embedding matrix")
        lengths = [e.shape[0] for e in quote_embeddings]
        max_len = max(lengths)
        n_items, hidden = len(quote_embeddings), quote_embeddings[0].shape[1]
        padded_embeddings = np.full((n_items, max_len, hidden), pad_value, dtype=quote_embeddings[0].dtype)
        padded_masks = np.zeros((n_items, max_len), dtype=np.int64)
        for i, (emb, length) in enumerate(zip(quote_embeddings, lengths)):
            padded_embeddings[i, :length, :] = emb
            padded_masks[i, :length] = 1
        return padded_embeddings, padded_masks

    def embed_quotes(self, quotes, pad_token_len=False):
        """Encode quotes into per-token embeddings with the doc marker at position 1.

        Args:
            quotes: A quote string or a list of quote strings.
            pad_token_len: If True, return ``pad_tok_len``'s padded array and mask.

        Returns:
            A list of per-quote np.ndarrays, or ``(padded, masks)`` when
            ``pad_token_len`` is True.
        """
        quote_embeddings = self._encode_token_level(quotes, self.doc_token_id, self.model.doc)
        if pad_token_len:
            return self.pad_tok_len(quote_embeddings)
        return quote_embeddings

    @staticmethod
    def colbert_score(query_embed, quote_embeddings, quote_masks):
        """Late-interaction (MaxSim) score of one query against padded quotes.

        Args:
            query_embed: (num_query_tokens, dim) array.
            quote_embeddings: (num_quotes, max_len, dim) array, e.g. from ``pad_tok_len``.
            quote_masks: (num_quotes, max_len) array; 1 marks real tokens.

        Returns:
            A (num_quotes,) array: for each quote, the sum over query tokens of
            the largest dot product with any unmasked quote token.
        """
        # sim[q, n, l] = <query token q, token l of quote n>
        sim = np.einsum('qh,nlh->qnl', query_embed, quote_embeddings)
        # Padding must never win the max, so push it far below any real score.
        sim = np.where(quote_masks[np.newaxis, :, :] == 1, sim, -1e9)
        return sim.max(axis=-1).sum(axis=0)

    def score(self, queries, quotes):
        """Return MaxSim scores of every query against every quote.

        Args:
            queries: A query string, a list of query strings, or a list of
                per-query (num_tokens, dim) np.ndarrays.
            quotes: A quote string, a list of quote strings, or a list of
                per-quote (num_tokens, dim) np.ndarrays.

        Returns:
            ``scores[i][j]`` is the MaxSim score of query ``i`` against quote ``j``.

        Raises:
            TypeError: if either argument is of an unsupported type.
        """
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(quotes, str):
            quotes = [quotes]

        if is_str_list(queries):
            query_embed = self.embed_queries(queries)
        elif is_np_list(queries):
            query_embed = queries
        else:
            raise TypeError("queries must be str, list of str or list of np.ndarray, got {}".format(
                type(queries).__name__))

        if is_str_list(quotes):
            quote_embed, quote_masks = self.embed_quotes(quotes, pad_token_len=True)
        elif is_np_list(quotes):
            quote_embed, quote_masks = self.pad_tok_len(quotes)
        else:
            raise TypeError("quotes must be str, list of str or list of np.ndarray, got {}".format(
                type(quotes).__name__))

        return [self.colbert_score(q_embed, quote_embed, quote_masks).tolist()
                for q_embed in query_embed]