Spaces:
Sleeping
Sleeping
Upd api
Browse files - api/retrieval.py +31 -10
api/retrieval.py
CHANGED
|
@@ -181,7 +181,8 @@ class _NvidiaReranker:
|
|
| 181 |
"""Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""
|
| 182 |
def __init__(self):
|
| 183 |
self.api_key = os.getenv("NVIDIA_URI")
|
| 184 |
-
|
|
|
|
| 185 |
# NIM rerank endpoint (subject to environment); keep configurable
|
| 186 |
self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
|
| 187 |
self.timeout_s = 30
|
|
@@ -194,16 +195,36 @@ class _NvidiaReranker:
|
|
| 194 |
headers = {
|
| 195 |
"Authorization": f"Bearer {self.api_key}",
|
| 196 |
"Content-Type": "application/json",
|
|
|
|
| 197 |
}
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
try:
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
# Expecting a list with scores and indices or texts
|
| 208 |
results = []
|
| 209 |
entries = data.get("results") or data.get("data") or []
|
|
@@ -227,4 +248,4 @@ class _NvidiaReranker:
|
|
| 227 |
except Exception as e:
|
| 228 |
logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
|
| 229 |
# On failure, return original order with neutral scores
|
| 230 |
-
return [{"text": d, "score": 0.0} for d in documents]
|
|
|
|
| 181 |
"""Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""
|
| 182 |
def __init__(self):
|
| 183 |
self.api_key = os.getenv("NVIDIA_URI")
|
| 184 |
+
# Use the model identifier given in the provider's documentation
|
| 185 |
+
self.model = os.getenv("NVIDIA_RERANK_MODEL", "nv-rerank-qa-mistral-4b:1")
|
| 186 |
# NIM rerank endpoint (subject to environment); keep configurable
|
| 187 |
self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
|
| 188 |
self.timeout_s = 30
|
|
|
|
| 195 |
headers = {
|
| 196 |
"Authorization": f"Bearer {self.api_key}",
|
| 197 |
"Content-Type": "application/json",
|
| 198 |
+
"Accept": "application/json",
|
| 199 |
}
|
| 200 |
+
# Cap the number of candidates and truncate each document to avoid 4xx responses
|
| 201 |
+
docs = documents[:10]
|
| 202 |
+
docs = [d[:2000] for d in docs if isinstance(d, str)]
|
| 203 |
+
# Try two payload shapes, based on the provider's documentation
|
| 204 |
+
payloads = [
|
| 205 |
+
{
|
| 206 |
+
"model": self.model,
|
| 207 |
+
"query": {"text": query},
|
| 208 |
+
"passages": [{"text": d} for d in docs],
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"model": self.model,
|
| 212 |
+
"query": query,
|
| 213 |
+
"documents": [{"text": d} for d in docs],
|
| 214 |
+
},
|
| 215 |
+
]
|
| 216 |
try:
|
| 217 |
+
data = None
|
| 218 |
+
for p in payloads:
|
| 219 |
+
resp = requests.post(self.base_url, headers=headers, json=p, timeout=self.timeout_s)
|
| 220 |
+
if resp.status_code >= 400:
|
| 221 |
+
# try next shape
|
| 222 |
+
continue
|
| 223 |
+
data = resp.json()
|
| 224 |
+
break
|
| 225 |
+
if data is None:
|
| 226 |
+
# every payload shape failed: raise from the last response for diagnostics
|
| 227 |
+
resp.raise_for_status()
|
| 228 |
# Expecting a list with scores and indices or texts
|
| 229 |
results = []
|
| 230 |
entries = data.get("results") or data.get("data") or []
|
|
|
|
| 248 |
except Exception as e:
|
| 249 |
logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
|
| 250 |
# On failure, return original order with neutral scores
|
| 251 |
+
return [{"text": d, "score": 0.0} for d in documents]
|