# KHome / app.py — author: hanxu22, commit 1f16925 ("add rerank test").
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Load the reranker checkpoint together with its matching tokenizer.
model_name = "BAAI/bge-reranker-v2-m3"  # replace with the name of your rerank model
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
# Example input: the passages to rerank, and the query they are scored against.
candidates = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "Madrid is the capital of Spain.",
]
query = "What is the capital of France?"
# Score every candidate against the query: one forward pass per
# (query, candidate) pair. Gradient tracking is switched off once for the
# whole loop because this is inference only.
scores = []
with torch.no_grad():
    for candidate in candidates:
        # Encode query and candidate jointly as a single sequence pair;
        # truncation=True clips pairs that exceed the tokenizer's max length.
        inputs = tokenizer(
            query,
            candidate,
            truncation=True,
            return_tensors="pt",
        )
        logits = model(**inputs).logits
        # .item() only succeeds on a single-element tensor, i.e. this assumes
        # the classification head emits one relevance logit per pair.
        scores.append(logits.item())
# Rerank: sort (score, candidate) pairs by score, highest first. Sorting on
# the score alone is stable, so tied candidates keep their input order and
# the candidate texts themselves are never compared.
ranked_pairs = sorted(
    zip(scores, candidates),
    key=lambda pair: pair[0],
    reverse=True,
)
ranked_candidates = [candidate for _, candidate in ranked_pairs]
# Print the ranking. Each score comes from the same pair as its candidate;
# indexing the unsorted `scores` list by rank (as before) reported the score
# of whichever candidate originally sat at that position.
for rank, (score, candidate) in enumerate(ranked_pairs, start=1):
    print(f"Rank {rank}: {candidate} (Score: {score})")