# omnidesk-ai-test / reranking.py
# makcrx — commit 05da059 ("test")
from pathlib import Path
from sentence_transformers.cross_encoder import CrossEncoder
from more_itertools import windowed
# Cross-encoder used by rerank() to score [query, candidate] pairs.
# Instantiated once at import time, pinned to the CPU; max_length=512 is the
# maximum input sequence length handed to CrossEncoder.
model = CrossEncoder(
    'cross-encoder/mmarco-mMiniLMv2-L12-H384-v1',
    max_length=512,
    device='cpu',
)
def rerank(sentence_combinations, *, cross_encoder=None):
    """Score sentence pairs with a cross-encoder and rank them.

    Args:
        sentence_combinations: Sequence of ``[query, candidate]`` pairs in
            the form accepted by ``CrossEncoder.predict``.
        cross_encoder: Optional object exposing ``predict(pairs)`` to use
            instead of the module-level ``model``. Defaults to ``model``.

    Returns:
        A list of ``(score, index)`` tuples sorted by score, highest first,
        where ``index`` is the position of the pair in
        ``sentence_combinations``. Pairs with equal scores keep their input
        order. An empty input yields ``[]`` without calling the model.
    """
    # Nothing to score: skip the model call, whose handling of an empty
    # batch varies by sentence-transformers version (it may raise).
    if len(sentence_combinations) == 0:
        return []
    encoder = model if cross_encoder is None else cross_encoder
    similarity_scores = encoder.predict(sentence_combinations)
    scored = [(score, idx) for idx, score in enumerate(similarity_scores)]
    # sorted() stays stable with reverse=True, so ties keep input order.
    return sorted(scored, key=lambda pair: pair[0], reverse=True)
def search(query, sentences):
    """Return the best-scoring candidate sentence for ``query``.

    Each sentence is paired with the query as ``[query, sentence]`` and the
    pairs are ranked by ``rerank``.

    Args:
        query: The query text.
        sentences: Iterable of candidate sentences.

    Returns:
        The ``(score, index)`` tuple with the highest score, where ``index``
        is the position of the winning sentence in ``sentences``.

    Raises:
        IndexError: If ``sentences`` is empty.
    """
    pairs = [[query, sentence] for sentence in sentences]
    if not pairs:
        # Fail fast with a clear message instead of sending an empty batch to
        # the model; IndexError matches what indexing an empty ranking raised.
        raise IndexError("search() requires at least one candidate sentence")
    return rerank(pairs)[0]