import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Lazily-populated cache of the most recently loaded cross-encoder and its
# tokenizer; both are (re)loaded by rerank_with_cross_encoder() on demand.
cross_encoder = None
cross_enc_tokenizer = None

# Maximum number of documents kept after reranking. os.getenv() returns a
# string whenever the variable is set, so convert explicitly: a string cannot
# be used as a slice bound. A non-integer value raises ValueError at import.
TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", 40))


@torch.no_grad()
def rerank_with_cross_encoder(cross_enc_name, documents, query):
    """Rerank *documents* against *query* with a cross-encoder model.

    Args:
        cross_enc_name: Model identifier passed to ``from_pretrained``. When
            ``None``, reranking is skipped.
        documents: Candidate documents; they are handed to the tokenizer as
            the second element of each (query, document) pair
            (presumably a list of strings -- confirm against callers).
        query: The query paired with every document.

    Returns:
        ``documents`` unchanged when ``cross_enc_name`` is ``None`` or there
        is at most one document; otherwise a new list holding at most
        ``TOP_K_RERANK`` documents ordered by descending model score.
    """
    if cross_enc_name is None or len(documents) <= 1:
        return documents

    global cross_encoder, cross_enc_tokenizer
    needs_load = (
        cross_encoder is None
        or cross_enc_tokenizer is None
        or cross_encoder.name_or_path != cross_enc_name
    )
    if needs_load:
        # Load both pieces into locals first and publish them together, so a
        # failure while loading the tokenizer cannot leave a cached model
        # paired with a missing or stale tokenizer.
        model = AutoModelForSequenceClassification.from_pretrained(cross_enc_name)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(cross_enc_name)
        cross_encoder, cross_enc_tokenizer = model, tokenizer

    features = cross_enc_tokenizer(
        [query] * len(documents),
        documents,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    # Drop only the trailing label dimension; assumes the model emits a single
    # relevance logit per pair, i.e. logits of shape (len(documents), 1)
    # -- TODO confirm for the models in use.
    scores = cross_encoder(**features).logits.squeeze(-1)
    ranks = torch.argsort(scores, descending=True)

    # Convert to plain Python ints before indexing the document list.
    top_indices = ranks[:TOP_K_RERANK].tolist()
    return [documents[i] for i in top_indices]