gemma / retriever /reranker.py
dasomaru's picture
Upload folder using huggingface_hub
9b14ff1 verified
raw
history blame contribute delete
1.04 kB
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# 1. Load the reranker tokenizer and model.
# Both are created once at import time as module-level objects and are used by
# rerank_documents() below; they are loaded from the same checkpoint name.
reranker_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
reranker_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-base")
def rerank_documents(query: str, docs: list, top_k: int = 5) -> list:
    """
    Re-order retrieved documents by their relevance to the query.

    Every document is paired with ``query`` and scored by the module-level
    ``reranker_model`` (inputs built with ``reranker_tokenizer``). Documents
    are then sorted by score, highest first.

    Args:
        query: The search query text.
        docs: Candidate document texts to score against the query.
        top_k: Maximum number of documents to return. ``None`` returns every
            document; zero or a negative value returns an empty list.

    Returns:
        A new list holding at most ``top_k`` documents, best match first.
        ``docs`` itself is not modified.
    """
    # Nothing to rank (or nothing requested): return before touching the
    # tokenizer/model, which are not meant to be fed an empty batch.
    # A negative top_k would otherwise slice from the end of the ranking.
    if not docs or (top_k is not None and top_k <= 0):
        return []

    pairs = [(query, doc) for doc in docs]

    # Call the tokenizer directly rather than the legacy batch_encode_plus.
    # Pairs longer than max_length tokens are truncated, shorter ones padded
    # so the whole batch forms a single tensor.
    inputs = reranker_tokenizer(
        pairs,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Drop the trailing dimension of the logits; expected (batch_size,).
        scores = reranker_model(**inputs).logits.squeeze(-1).tolist()

    # Sort by score, highest first. Only the score is used as the key, so the
    # documents themselves are never compared and ties keep their input order.
    ranked = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked[:top_k]]