File size: 596 Bytes
0102b8a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from sentence_transformers import CrossEncoder


def rerank_documents(ce_model_name, documents, query, top_k_rerank):
    """Rerank documents against a query with a cross-encoder.

    Every document is paired with ``query``, all pairs are scored in a
    single ``CrossEncoder.predict`` call, and the documents are returned
    ordered from highest to lowest score.

    Args:
        ce_model_name: Model name or path passed to ``CrossEncoder``.
        documents: Iterable of documents to rerank (presumably text
            strings -- confirm against the caller).
        query: The query each document is scored against.
        top_k_rerank: Maximum number of documents to return; coerced
            with ``int()`` so numeric strings are accepted.

    Returns:
        A list of at most ``top_k_rerank`` documents, best score first.
        An empty list when ``documents`` is empty.

    Raises:
        ValueError: If ``top_k_rerank`` cannot be converted to ``int``.
    """
    top_k_rerank = int(top_k_rerank)

    # Build every (query, document) pair up front so the model scores the
    # whole candidate set at once instead of returning after the first one.
    pairs = [(query, doc) for doc in documents]
    if not pairs:
        # Nothing to score: skip the (expensive) model load entirely.
        return []

    ce_model = CrossEncoder(ce_model_name, max_length=512)
    scores = ce_model.predict(pairs)

    # Sort on the score only; the sort is stable, so ties keep input order.
    ranked = sorted(zip(scores, pairs), key=lambda item: item[0], reverse=True)
    return [pair[1] for _, pair in ranked[:top_k_rerank]]