# canap / ranker.py
# mnemlaghi's picture
# last files for commit
# b3f3132
# raw
# history blame
# 4.99 kB
import json
import logging
import pickle
import string
from abc import ABC, abstractmethod

import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
class AbstractMoviesRanker(ABC):
    """Abstract base class for ranking the rows of a DataFrame against a query.

    Subclasses implement :meth:`encode_query`. Scoring is a dot product between
    the encoded query and every row of ``index_matrix``.

    Args:
        df: items to rank. Row ``i`` of ``df`` is paired with row ``i`` of
            ``index_matrix`` (NOTE(review): this alignment is not checked).
        index_matrix: 2-D tensor of item encodings, one row per item.
        score_name: name of the score column added to ranked results.
    """

    # Built once here rather than on every depunctuate() call.
    _PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self, df, index_matrix, score_name="score"):
        self.df = df
        self.ids = self.df.index.values
        self.index_matrix = index_matrix
        self.score_name = score_name

    @abstractmethod
    def encode_query(self, query):
        """Encode ``query`` as a 2-D tensor; only its first row is scored."""

    def get_scores(self, encoded_query):
        """Return the score of the query's first row against every item, as a list.

        The score is the dot product with each row of ``index_matrix``.
        """
        # Scores are only read as Python floats, so skip autograd bookkeeping.
        with torch.no_grad():
            product = torch.mm(encoded_query, self.index_matrix.transpose(0, 1))
        return product[0].tolist()

    def get_top_ids(self, scores, topn=6):
        """Return the ``topn`` highest-scoring rows of ``df``, best first.

        Args:
            scores: one score per row of ``df``, in row order.
            topn: maximum number of rows to return.

        Returns:
            A new DataFrame of the selected rows with a ``score_name`` column.
            ``self.df`` is not modified.
        """
        # Like the former zip(ids, scores): only pairs present in both are ranked.
        n_items = min(len(scores), len(self.ids))
        # Rank row *positions*, not index labels, so duplicate labels cannot
        # pull in extra rows. Python's sort is stable, so ties keep row order.
        order = sorted(range(n_items), key=scores.__getitem__, reverse=True)[:topn]
        # .copy() gives an independent frame, so adding the score column is
        # never a write into a (possible) copy of self.df.
        sorted_df = self.df.iloc[order].copy()
        sorted_df[self.score_name] = [scores[i] for i in order]
        return sorted_df

    def run_query(self, query, topn=6):
        """Encode ``query``, score every item and return the ``topn`` best rows."""
        encoded_query = self.encode_query(query)
        scores = self.get_scores(encoded_query)
        return self.get_top_ids(scores, topn)

    @staticmethod
    def depunctuate(text):
        """Return ``text`` with ASCII punctuation (``string.punctuation``) removed."""
        return text.translate(AbstractMoviesRanker._PUNCTUATION_TABLE)
class SparseTfIdfRanker(AbstractMoviesRanker):
    """Sparse ranking via TF-IDF.

    Queries are depunctuated, vectorised with a pickled vectorizer and
    L2-normalised. They are then scored against ``index_matrix``, which is
    densified once at construction.

    Args:
        df: items to rank, row-aligned with ``index_matrix``.
        index_matrix: 2-D item encodings; a sparse layout is converted to dense.
        vectorizer_path: path to a pickled object exposing ``transform``
            (presumably a scikit-learn TF-IDF vectorizer -- confirm).
    """

    def __init__(self, df, index_matrix, vectorizer_path):
        super().__init__(df, index_matrix, score_name='tfidf-score')
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load
        # vectorizers from trusted paths. `with` closes the file handle.
        with open(vectorizer_path, 'rb') as fp:
            self.vectorizer = pickle.load(fp)
        # Densify once so get_scores can use a plain torch.mm. Skip if the
        # matrix is already strided (dense).
        if self.index_matrix.layout != torch.strided:
            self.index_matrix = self.index_matrix.to_dense()

    def encode_query(self, query):
        """Return the L2-normalised float32 vector of ``query``.

        The result is 2-D, with one row for the single query document.
        """
        sparse_vec = self.vectorizer.transform([self.depunctuate(query)])
        encoded_query = torch.tensor(sparse_vec.todense(), dtype=torch.float32)
        return F.normalize(encoded_query, p=2)
class BertRanker(AbstractMoviesRanker):
    """Dense ranking with a transformer encoder and a precomputed embedding matrix.

    A query is tokenised, encoded, mean-pooled over its non-padding tokens and
    L2-normalised. It is then scored against ``index_matrix``.

    Args:
        df: items to rank, row-aligned with ``index_matrix``.
        index_matrix: 2-D tensor of item embeddings.
        modelpath: model name or path accepted by ``from_pretrained``.
    """

    def __init__(self, df, index_matrix, modelpath):
        super().__init__(df, index_matrix, score_name="bert-score")
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
        self.model = AutoModel.from_pretrained(modelpath)
        # from_pretrained already returns an eval-mode model; explicit for clarity.
        self.model.eval()

    def encode_query(self, query):
        """Return the L2-normalised mean-pooled embedding of ``query`` (2-D, one row)."""
        tok_q = self.tokenizer(query, return_tensors="pt", padding="max_length",
                               max_length=128, truncation=True)
        # Inference only: no autograd graph, which saves memory and time per query.
        with torch.no_grad():
            output = self.model(**tok_q)
        encoded_query = self.mean_pooling(output, tok_q['attention_mask'])
        return F.normalize(encoded_query, p=2)

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings over positions where ``attention_mask`` is 1.

        Args:
            model_output: encoder output; ``last_hidden_state`` is 3-D
                (batch, seq, hidden), as implied by the expand below.
            attention_mask: (batch, seq) mask from the tokenizer.

        Returns:
            A (batch, hidden) tensor of mean embeddings.
        """
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # The clamp avoids division by zero for an all-padding row.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
class SparseDenseMoviesRanker:
    """Two-stage ranker: TF-IDF retrieves candidates, then BERT re-ranks them.

    Args:
        df: full item DataFrame.
        modelpath: transformer model name/path for the dense re-ranker.
        bert_index: dense embeddings indexed by ``df`` index labels
            (NOTE(review): assumes labels are row positions in ``bert_index``).
        sparse_index: TF-IDF matrix, row-aligned with ``df``.
        vectorizer_path: path to the pickled TF-IDF vectorizer.
        first_ranking: default number of TF-IDF candidates passed to BERT.
    """

    # Sentinel: "use self.first_ranking". None stays a valid explicit value
    # (it slices as "all candidates").
    _USE_CONFIGURED = object()

    def __init__(self, df, modelpath, bert_index, sparse_index, vectorizer_path, first_ranking=1000):
        self.df = df
        self.ids = self.df.index.values
        self.tfidf_engine = SparseTfIdfRanker(df, sparse_index, vectorizer_path)
        self.modelpath = modelpath
        self.bert_index = bert_index
        self.first_ranking = first_ranking
        # Built on the first query and then reused, so model weights load once
        # instead of on every run_query call.
        self.bert_engine = None

    def run_query(self, query, topn=6, first_ranking=_USE_CONFIGURED):
        """Return the ``topn`` best rows for ``query`` after two-stage ranking.

        Args:
            query: free-text query.
            topn: number of rows returned by the BERT stage.
            first_ranking: number of TF-IDF candidates to re-rank; defaults to
                the value given at construction.
        """
        if first_ranking is self._USE_CONFIGURED:
            first_ranking = self.first_ranking
        tfidf_sorted_frame = self.tfidf_engine.run_query(query, topn=first_ranking)
        firstranking_index = self.bert_index[tfidf_sorted_frame.index.values]
        self._set_bert_candidates(tfidf_sorted_frame, firstranking_index)
        return self.bert_engine.run_query(query, topn=topn)

    def _set_bert_candidates(self, frame, index_matrix):
        """Point the cached BERT ranker at ``frame``/``index_matrix``, building it once."""
        if self.bert_engine is None:
            self.bert_engine = BertRanker(frame, index_matrix, self.modelpath)
            return
        # Same attributes AbstractMoviesRanker.__init__ assigns.
        self.bert_engine.df = frame
        self.bert_engine.ids = frame.index.values
        self.bert_engine.index_matrix = index_matrix

    @classmethod
    def from_json_config(cls, jsonfile):
        """Build a ranker from a JSON config file.

        Required keys: 'dataframe', 'bert_index', 'sparse_index',
        'vectorizer_path' and 'modelpath'. Optional: 'firstranking'.
        """
        with open(jsonfile) as fp:
            conf = json.load(fp)
        # Load data for ranking. NOTE: read_pickle/torch.load deserialize files;
        # only use trusted config paths.
        df = pd.read_pickle(conf['dataframe'])
        # Load indices (embeddings) and encoding utilities.
        bert_index = torch.load(conf['bert_index'])
        sparse_index = torch.load(conf['sparse_index'])
        vectorizer_path = conf['vectorizer_path']
        modelpath = conf['modelpath']
        # 'firstranking' used to be read and then ignored. Honour it when present;
        # otherwise keep the constructor default.
        extra = {}
        if 'firstranking' in conf:
            extra['first_ranking'] = conf['firstranking']
        return cls(df, modelpath, bert_index, sparse_index, vectorizer_path, **extra)
if __name__ == '__main__':
    # Demo: build the two-stage engine from conf.json and print the best
    # matches for a couple of sample (French) queries.
    demo_engine = SparseDenseMoviesRanker.from_json_config('conf.json')
    demo_queries = (
        "une histoire de pirates et de chasse au trésor",
        "une histoire de gangsters avec de l'argent",
    )
    for demo_query in demo_queries:
        print(demo_query)
        ranked = demo_engine.run_query(demo_query)
        print(ranked.head())