afriddev commited on
Commit
cd97e60
Β·
verified Β·
1 Parent(s): ec4ed0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -51
app.py CHANGED
@@ -1,82 +1,100 @@
1
  import gradio as gr
2
- from sentence_transformers import CrossEncoder
3
  import torch
4
  import requests
 
5
 
6
  # -------------------------------
7
- # CONFIG
8
  # -------------------------------
9
- HF_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
 
 
10
  JINA_MODEL = "jina-reranker-m0"
11
  JINA_API_KEY = "jina_4075150fa702471c85ddea0a9ad4b306ouE7ymhrCpvxTxX3mScUv5LLDPKQ"
12
  JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
13
 
14
  # -------------------------------
15
- # Load Hugging Face CrossEncoder
16
  # -------------------------------
17
- hf_model = CrossEncoder(HF_MODEL)
 
 
 
18
 
19
- def rerank(query, docs_text):
20
- # Split input documents (one per line)
21
- docs = [d.strip() for d in docs_text.split("\n") if d.strip()]
22
- if not docs:
23
- return "⚠️ No documents provided."
 
 
 
 
 
24
 
25
- # -------------------------------
26
- # Hugging Face CrossEncoder Scores
27
- # -------------------------------
28
- hf_scores = hf_model.predict([(query, d) for d in docs])
29
- hf_scores = [torch.sigmoid(torch.tensor(s)).item() for s in hf_scores]
30
- hf_ranking = sorted(zip(docs, hf_scores), key=lambda x: x[1], reverse=True)
31
 
32
- # -------------------------------
33
- # Jina Reranker API Scores
34
- # -------------------------------
35
- headers = {
36
- "Authorization": f"Bearer {JINA_API_KEY}",
37
- "Content-Type": "application/json",
38
- }
39
- payload = {
40
- "model": JINA_MODEL,
41
- "query": query,
42
- "documents": docs,
43
- }
 
 
44
  try:
45
- r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=20)
46
  r.raise_for_status()
47
- results = r.json()["results"]
48
- jina_scores = [res["relevance_score"] for res in results]
49
- jina_ranking = sorted(zip(docs, jina_scores), key=lambda x: x[1], reverse=True)
50
  except Exception as e:
51
- jina_ranking = [("Error", str(e))]
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  # -------------------------------
54
  # Format output
55
  # -------------------------------
56
- out = "### Hugging Face Ranking\n"
57
- for doc, score in hf_ranking:
58
- out += f"- ({score:.4f}) {doc}\n"
59
-
60
- out += "\n### Jina Reranker Ranking\n"
61
- for doc, score in jina_ranking:
62
- out += f"- ({score}) {doc}\n"
63
-
64
  return out
65
 
66
  # -------------------------------
67
- # Simple UI
68
  # -------------------------------
69
- with gr.Blocks() as demo:
70
- gr.Markdown("### πŸ”Ž Query + Multiple Docs Reranking (HF vs Jina)")
71
- query = gr.Textbox(label="Query", lines=2, placeholder="Enter your query here...")
 
72
  docs = gr.Textbox(
73
- label="Candidate Documents (one per line)",
74
- lines=10,
75
- placeholder="Paste multiple document chunks here, each on a new line..."
76
  )
77
- out = gr.Textbox(label="Ranked Results", lines=15)
78
 
79
- btn = gr.Button("Rerank πŸš€")
80
- btn.click(rerank, inputs=[query, docs], outputs=out)
81
 
82
  demo.launch()
 
1
  import gradio as gr
2
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
3
  import torch
4
  import requests
5
+ import ast
6
 
7
  # -------------------------------
8
+ # MODELS
9
  # -------------------------------
10
+ BI_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
11
+ CROSS_ENCODER_RERANK = "cross-encoder/ms-marco-MiniLM-L-12-v2"
12
+ CROSS_ENCODER_STS = "cross-encoder/stsb-roberta-large"
13
+ CROSS_ENCODER_NLI = "cross-encoder/nli-deberta-v3-base"
14
  JINA_MODEL = "jina-reranker-m0"
15
  JINA_API_KEY = "jina_4075150fa702471c85ddea0a9ad4b306ouE7ymhrCpvxTxX3mScUv5LLDPKQ"
16
  JINA_ENDPOINT = "https://api.jina.ai/v1/rerank"
17
 
18
  # -------------------------------
19
+ # Load models
20
  # -------------------------------
21
+ bi_encoder = SentenceTransformer(BI_ENCODER)
22
+ ce_rerank = CrossEncoder(CROSS_ENCODER_RERANK)
23
+ ce_sts = CrossEncoder(CROSS_ENCODER_STS)
24
+ ce_nli = CrossEncoder(CROSS_ENCODER_NLI, num_labels=3)
25
 
26
+ # -------------------------------
27
+ # Pipeline Function
28
+ # -------------------------------
29
+ def evaluate_models(query, docs_str):
30
+ try:
31
+ # Parse docs string as Python list
32
+ docs = ast.literal_eval(docs_str)
33
+ assert isinstance(docs, list), "Input must be a Python list of strings"
34
+ except Exception as e:
35
+ return f"⚠️ Error parsing documents list: {e}"
36
 
37
+ results = {}
 
 
 
 
 
38
 
39
+ # 1. Bi-encoder cosine similarity
40
+ query_emb = bi_encoder.encode(query, convert_to_tensor=True)
41
+ doc_embs = bi_encoder.encode(docs, convert_to_tensor=True)
42
+ cos_scores = util.cos_sim(query_emb, doc_embs)[0].cpu().tolist()
43
+ results["1. Bi-encoder similarity"] = sorted(zip(docs, cos_scores), key=lambda x: x[1], reverse=True)
44
+
45
+ # 2. CrossEncoder reranker (MS MARCO)
46
+ ce_rerank_scores = ce_rerank.predict([(query, d) for d in docs])
47
+ ce_rerank_scores = [torch.sigmoid(torch.tensor(s)).item() for s in ce_rerank_scores]
48
+ results["2. CrossEncoder Reranker (MS MARCO)"] = sorted(zip(docs, ce_rerank_scores), key=lambda x: x[1], reverse=True)
49
+
50
+ # 3. Jina Reranker
51
+ headers = {"Authorization": f"Bearer {JINA_API_KEY}", "Content-Type": "application/json"}
52
+ payload = {"model": JINA_MODEL, "query": query, "documents": docs}
53
  try:
54
+ r = requests.post(JINA_ENDPOINT, headers=headers, json=payload, timeout=30)
55
  r.raise_for_status()
56
+ jina_scores = [res["relevance_score"] for res in r.json()["results"]]
57
+ results["3. Jina Reranker"] = sorted(zip(docs, jina_scores), key=lambda x: x[1], reverse=True)
 
58
  except Exception as e:
59
+ results["3. Jina Reranker"] = [(f"Error: {e}", 0)]
60
+
61
+ # 4. CrossEncoder STS
62
+ ce_sts_scores = ce_sts.predict([(query, d) for d in docs])
63
+ results["4. CrossEncoder STS"] = sorted(zip(docs, ce_sts_scores), key=lambda x: x[1], reverse=True)
64
+
65
+ # 5. CrossEncoder NLI
66
+ ce_nli_probs = ce_nli.predict([(query, d) for d in docs], apply_softmax=True)
67
+ ce_nli_scores = [float(p[1] + p[2]) for p in ce_nli_probs] # neutral + entailment
68
+ results["5. CrossEncoder NLI"] = sorted(zip(docs, ce_nli_scores), key=lambda x: x[1], reverse=True)
69
+
70
+ # 6. Bi-encoder raw similarity (duplicate for clarity)
71
+ results["6. Bi-encoder baseline"] = sorted(zip(docs, cos_scores), key=lambda x: x[1], reverse=True)
72
 
73
  # -------------------------------
74
  # Format output
75
  # -------------------------------
76
+ out = ""
77
+ for model_name, ranked in results.items():
78
+ out += f"\n### {model_name}\n"
79
+ for doc, score in ranked:
80
+ out += f"- ({round(score,4)}) {doc}\n"
 
 
 
81
  return out
82
 
83
  # -------------------------------
84
+ # Gradio UI
85
  # -------------------------------
86
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
87
+ gr.Markdown("## πŸ”Ž Multi-Model Reranker (HF + Jina)\nPass a **query** and a **list of documents (Python list of strings)**.")
88
+
89
+ query = gr.Textbox(label="Query", lines=2, placeholder="Enter your search query...")
90
  docs = gr.Textbox(
91
+ label="Documents (Python list)",
92
+ lines=6,
93
+ placeholder='Example: ["Doc one text", "Doc two text", "Doc three text"]'
94
  )
95
+ out = gr.Textbox(label="Ranked Results", lines=20)
96
 
97
+ btn = gr.Button("Evaluate πŸš€")
98
+ btn.click(evaluate_models, inputs=[query, docs], outputs=out)
99
 
100
  demo.launch()