z00mP's picture
add reranker
0102b8a
raw history blame
No virus
596 Bytes
from sentence_transformers import CrossEncoder
def rerank_documents(ce_model_name, documents, query, top_k_rerank):
    """Rerank documents against a query with a cross-encoder and keep the best.

    Every document is paired with ``query``, scored by the cross-encoder and
    the documents are returned from highest to lowest score.

    Args:
        ce_model_name: Model name or path handed to ``CrossEncoder``.
        documents: Iterable of candidate documents to score.
        query: Query that each document is paired with.
        top_k_rerank: Number of documents to keep; coerced with ``int()``.
            A negative value follows ordinary slice semantics.

    Returns:
        A list of documents in descending score order, cut with
        ``[:top_k_rerank]``. Documents with equal scores keep their input
        order. An empty list is returned, without loading the model, when
        there are no documents or ``top_k_rerank`` is 0.
    """
    top_k_rerank = int(top_k_rerank)
    # Materialize once so generators are supported and emptiness is checkable.
    documents = list(documents)

    # Nothing to rank or nothing requested: skip the model load entirely.
    if not documents or top_k_rerank == 0:
        return []

    # NOTE(review): the model is constructed on every call; callers that
    # rerank repeatedly with the same model may want to cache it upstream.
    ce_model = CrossEncoder(ce_model_name, max_length=512)
    scores = ce_model.predict([(query, doc) for doc in documents])

    # sorted() is stable, also with reverse=True, so ties keep input order.
    ranked = sorted(
        zip(scores, documents),
        key=lambda scored: scored[0],
        reverse=True,
    )
    return [doc for _, doc in ranked][:top_k_rerank]