Spaces:
Sleeping
Sleeping
File size: 3,448 Bytes
3bf3346 25766a9 4ffa42b 7ca8abc 7ec684a cd97e60 c7ff2c3 7ec684a cd97e60 7ec684a cd97e60 7ec684a c7ff2c3 7ec684a c7ff2c3 7ec684a cd97e60 7ec684a cd97e60 7ec684a cd97e60 c7ff2c3 7ec684a cd97e60 ec4ed0d 25766a9 cd97e60 5f51df4 cd97e60 25766a9 c7ff2c3 cd97e60 25766a9 c7ff2c3 ec4ed0d 5f51df4 7ec684a cd97e60 7ec684a cd97e60 c7ff2c3 cd97e60 ec4ed0d cd97e60 c7ff2c3 ec4ed0d 5f51df4 063bf3b cd97e60 33e4eda 7ff30bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gradio as gr
from sentence_transformers import CrossEncoder
import torch
import requests
import ast
import os
# -------------------------------
# MODELS
# -------------------------------
# Checkpoint name handed to CrossEncoder below; scored locally.
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Jina reranker: model name sent in the request payload, and the URL it is
# posted to.
JINA_MODEL = "jina-reranker-m0"
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"

# NVIDIA reranker: model name appended to the HF Inference API URL.
NV_MODEL = "NV-RerankQA-Mistral-4B-v3"

# Credentials come from the environment (set in HF Space settings); each is
# None when its variable is not defined.
JINA_API_KEY = os.environ.get("JINA_API_KEY")
HF_API_KEY = os.environ.get("HF_API_KEY")

# -------------------------------
# Load models
# -------------------------------
# Instantiated once at import time and reused by every request.
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
# -------------------------------
# Pipeline Function
# -------------------------------
def _parse_docs(docs_str):
    """Parse *docs_str* into a non-empty list of document strings.

    Args:
        docs_str: Text of a Python list literal, e.g. '["doc a", "doc b"]'.

    Returns:
        The parsed list of strings.

    Raises:
        ValueError: If the literal is malformed, is not a list, contains a
            non-string element, or is empty.
        SyntaxError, TypeError, MemoryError, RecursionError: Propagated
            from ``ast.literal_eval`` for input it cannot evaluate.
    """
    docs = ast.literal_eval(docs_str)
    # Explicit checks rather than `assert`, which is removed under `python -O`.
    if not isinstance(docs, list) or not all(isinstance(d, str) for d in docs):
        raise ValueError("Input must be a Python list of strings")
    if not docs:
        raise ValueError("Documents list is empty")
    return docs


def evaluate_models(query, docs_str):
    """Score every document against *query* with three rerankers.

    Args:
        query: The search query text.
        docs_str: Text of a Python list literal holding the documents.

    Returns:
        A dict keyed by model label. Each value is a list of scores rounded
        to 4 decimals, or a one-element list holding an "Error: ..." string
        when that model could not be evaluated. If *docs_str* cannot be
        parsed into a non-empty list of strings, the dict instead has the
        single key "Error" with a description of the problem.
    """
    try:
        docs = _parse_docs(docs_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        return {"Error": f"⚠️ Error parsing documents list: {e}"}

    results = {}

    # 1. CrossEncoder reranker (MS MARCO)
    # Guarded like the remote backends so a local failure is reported in the
    # result instead of aborting the whole comparison.
    try:
        raw_scores = ce_rerank.predict([(query, d) for d in docs])
        # Squash each raw score into (0, 1) before rounding.
        results["CrossEncoder (MS MARCO)"] = [
            round(torch.sigmoid(torch.tensor(s)).item(), 4) for s in raw_scores
        ]
    except Exception as e:
        results["CrossEncoder (MS MARCO)"] = [f"Error: {e}"]

    # 2. Jina Reranker
    if JINA_API_KEY:
        headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
        payload = {"model": JINA_MODEL, "query": query, "documents": docs}
        try:
            r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
            r.raise_for_status()
            # Results carry their own "index", so write each score back into
            # the slot of the document it refers to; slots the response does
            # not mention stay at 0.
            jina_scores = [0] * len(docs)
            for res in r.json()["results"]:
                jina_scores[res["index"]] = round(res["relevance_score"], 4)
            results["Jina Reranker"] = jina_scores
        except Exception as e:
            results["Jina Reranker"] = [f"Error: {e}"]
    else:
        results["Jina Reranker"] = ["Error: Missing JINA_API_KEY"]

    # 3. NV RerankQA Mistral-4B-v3 (HF Inference API)
    if HF_API_KEY:
        try:
            hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
            headers = {"Authorization": f"Bearer {HF_API_KEY}"}
            payload = {"inputs": {"query": query, "documents": docs}}
            r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
            r.raise_for_status()
            # NOTE(review): assumes the response is a list of objects with a
            # "score" field, in the same order as `docs` -- confirm against
            # the endpoint's actual response format.
            nv_scores = [round(res["score"], 4) for res in r.json()]
            results["NV-RerankQA-Mistral-4B-v3"] = nv_scores
        except Exception as e:
            results["NV-RerankQA-Mistral-4B-v3"] = [f"Error: {e}"]
    else:
        results["NV-RerankQA-Mistral-4B-v3"] = ["Error: Missing HF_API_KEY"]

    return results
# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        value="## π Ranking Battle (Aligned Scores)\nOutputs only **scores aligned to input docs** from 3 models."
    )

    # Inputs: free-text query plus the documents typed as a Python list literal.
    query = gr.Textbox(
        label="Query",
        lines=2,
        placeholder="Enter your search query...",
    )
    docs = gr.Textbox(
        label="Documents (Python list)",
        lines=6,
        placeholder='Example: [\"Doc one text\", \"Doc two text\", \"Doc three text\"]',
    )

    # Output panel showing the dict returned by evaluate_models.
    out = gr.JSON(label="Model Scores")

    # Clicking the button runs the evaluation on the two textboxes.
    btn = gr.Button(value="Evaluate π")
    btn.click(fn=evaluate_models, inputs=[query, docs], outputs=out)

demo.launch()
|