from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import torch | |
# 1. Load the reranker: a tokenizer plus a sequence-classification model
#    that rerank_documents() uses to score (query, document) pairs.
#    Both are loaded once, at import time.
_RERANKER_CHECKPOINT = "BAAI/bge-reranker-base"
reranker_tokenizer = AutoTokenizer.from_pretrained(_RERANKER_CHECKPOINT)
reranker_model = AutoModelForSequenceClassification.from_pretrained(_RERANKER_CHECKPOINT)
def rerank_documents(query: str, docs: list, top_k: int = 5) -> list:
    """
    Re-order retrieved documents by their relevance to ``query``.

    Each (query, document) pair is encoded jointly and scored by the
    module-level ``reranker_model``; documents are returned in descending
    score order.

    Args:
        query: The search query.
        docs: Candidate documents (strings). Any iterable is accepted; it is
            materialised once so a generator is not exhausted before sorting.
        top_k: Maximum number of documents to return. Values <= 0 yield
            ``[]``; ``None`` returns every document.

    Returns:
        Up to ``top_k`` documents from ``docs``, highest score first.
        Documents with equal scores keep their original relative order.
    """
    # Materialise once: docs are needed both for encoding and for the final
    # ordering, and a generator would be empty on the second pass.
    docs = list(docs)
    # Nothing to score / nothing requested: skip the tokenizer and model,
    # which are not guaranteed to handle an empty batch.
    if not docs or (top_k is not None and top_k <= 0):
        return []

    pairs = [(query, doc) for doc in docs]
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated batch_encode_plus; a list of 2-tuples is encoded as pairs.
    inputs = reranker_tokenizer(
        pairs,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    # Keep inputs on the same device as the model (no-op when both are on CPU).
    inputs = inputs.to(reranker_model.device)

    with torch.no_grad():
        logits = reranker_model(**inputs).logits
    # Drop the trailing logit axis so there is one score per pair: (batch_size,)
    scores = logits.squeeze(-1).tolist()

    # Sort positions by score, highest first; sorted() is stable, so ties keep
    # their input order.
    order = sorted(range(len(docs)), key=scores.__getitem__, reverse=True)
    return [docs[i] for i in order[:top_k]]