File size: 3,448 Bytes
3bf3346
25766a9
4ffa42b
7ca8abc
7ec684a
cd97e60
c7ff2c3
7ec684a
 
cd97e60
7ec684a
cd97e60
7ec684a
c7ff2c3
7ec684a
c7ff2c3
 
7ec684a
 
cd97e60
7ec684a
cd97e60
7ec684a
cd97e60
 
 
 
 
 
 
 
c7ff2c3
7ec684a
cd97e60
ec4ed0d
25766a9
cd97e60
5f51df4
 
cd97e60
25766a9
c7ff2c3
 
 
 
 
 
 
 
 
 
 
 
 
 
cd97e60
25766a9
c7ff2c3
 
 
 
 
 
 
 
 
 
 
 
 
ec4ed0d
5f51df4
7ec684a
 
cd97e60
7ec684a
cd97e60
c7ff2c3
cd97e60
 
ec4ed0d
cd97e60
 
c7ff2c3
ec4ed0d
5f51df4
063bf3b
cd97e60
 
33e4eda
7ff30bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
from sentence_transformers import CrossEncoder
import torch

import requests
import ast
import os

# -------------------------------
# MODELS
# -------------------------------
# Local cross-encoder checkpoint, loaded in-process below.
CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Jina hosted rerank API; the key comes from the environment (None if unset).
JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
JINA_MODEL = "jina-reranker-m0"
JINA_API_KEY = os.getenv("JINA_API_KEY")  # set in HF Space settings

# NVIDIA reranker queried through the HF Inference API; key from the environment.
# NOTE(review): Hub model URLs normally take an "org/name" id. Confirm whether this
# needs a namespace prefix for the inference endpoint to resolve.
NV_MODEL = "NV-RerankQA-Mistral-4B-v3"
HF_API_KEY = os.getenv("HF_API_KEY")  # set in HF Space settings

# -------------------------------
# Load models
# -------------------------------
# Built once at import time so every request reuses the same instance.
ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)

# -------------------------------
# Pipeline Function
# -------------------------------
# Parse the user's list literal and validate it is a non-empty list of strings.
def _parse_docs(docs_str):
    """Return ``(docs, None)`` on success or ``(None, error)`` on failure."""
    try:
        # literal_eval only accepts Python literals, so untrusted text cannot run code.
        docs = ast.literal_eval(docs_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        return None, e
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``.
    if not isinstance(docs, list) or not all(isinstance(d, str) for d in docs):
        return None, "Input must be a Python list of strings"
    if not docs:
        return None, "Input must contain at least one document"
    return docs, None


# Place each result's rounded score at its document index (None where absent).
def _align_scores(items, n_docs, score_key):
    """Map API result items onto a list aligned with the input documents.

    Each item's ``"index"`` key is used when present; otherwise the item's
    position in the response is assumed to match the input order.

    Raises:
        IndexError: if a reported index falls outside ``[0, n_docs)``.
    """
    scores = [None] * n_docs
    for pos, item in enumerate(items):
        idx = item.get("index", pos)
        # Guard explicitly: a negative index would otherwise silently wrap around.
        if not 0 <= idx < n_docs:
            raise IndexError(f"result index {idx} out of range for {n_docs} documents")
        scores[idx] = round(item[score_key], 4)
    return scores


# Score with the local MS MARCO cross-encoder, squashed into (0, 1) by a sigmoid.
def _score_cross_encoder(query, docs):
    try:
        logits = ce_rerank.predict([(query, d) for d in docs])
        # NOTE(review): assumes predict() returns raw logits for this checkpoint;
        # if the installed sentence-transformers already applies a sigmoid, this
        # applies it twice.
        return [round(torch.sigmoid(torch.tensor(s)).item(), 4) for s in logits]
    except Exception as e:
        return [f"Error: {e}"]


# Query the Jina rerank API and re-align its results to input order.
def _score_jina(query, docs):
    if not JINA_API_KEY:
        return ["Error: Missing JINA_API_KEY"]
    headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": JINA_MODEL, "query": query, "documents": docs}
    try:
        r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
        r.raise_for_status()
        return _align_scores(r.json()["results"], len(docs), "relevance_score")
    except Exception as e:
        return [f"Error: {e}"]


# Query the HF Inference API for the NV reranker and align scores to input order.
def _score_nv(query, docs):
    if not HF_API_KEY:
        return ["Error: Missing HF_API_KEY"]
    hf_endpoint = f"https://api-inference.huggingface.co/models/{NV_MODEL}"
    headers = {"Authorization": f"Bearer {HF_API_KEY}"}
    payload = {"inputs": {"query": query, "documents": docs}}
    try:
        r = requests.post(hf_endpoint, headers=headers, json=payload, timeout=60)
        r.raise_for_status()
        # NOTE(review): response schema unverified. Items are placed by their
        # "index" key when present; otherwise response order is assumed.
        return _align_scores(r.json(), len(docs), "score")
    except Exception as e:
        return [f"Error: {e}"]


def evaluate_models(query, docs_str):
    """Score every document against ``query`` with three rerankers.

    Args:
        query: The search query text.
        docs_str: A Python-literal list of document strings, e.g.
            ``'["doc a", "doc b"]'``.

    Returns:
        ``{"Error": "<message>"}`` when ``docs_str`` is not a non-empty list of
        strings. Otherwise, a dict mapping each model's display name to one of:

        * a list whose i-th entry is the score for ``docs[i]``. The entry is
          ``None`` when that model returned no score for the document.
        * a one-element ``["Error: ..."]`` list when that model failed or its
          API key is missing.
    """
    docs, error = _parse_docs(docs_str)
    if error is not None:
        return {"Error": f"⚠️ Error parsing documents list: {error}"}

    return {
        "CrossEncoder (MS MARCO)": _score_cross_encoder(query, docs),
        "Jina Reranker": _score_jina(query, docs),
        "NV-RerankQA-Mistral-4B-v3": _score_nv(query, docs),
    }

# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header text. The emoji were previously stored mis-encoded ("πŸ‘‘" / "πŸš€").
    gr.Markdown("## 👑 Ranking Battle (Aligned Scores)\nOutputs only **scores aligned to input docs** from 3 models.")

    query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
    # evaluate_models parses this text with ast.literal_eval, so it must be a list literal.
    docs = gr.Textbox(
        label="Documents (Python list)",
        lines=6,
        placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]',
    )
    out = gr.JSON(label="Model Scores")

    btn = gr.Button("Evaluate 🚀")
    # The clicked handler receives (query, docs) text and its dict result is rendered as JSON.
    btn.click(evaluate_models, inputs=[query, docs], outputs=out)

demo.launch()