BinKhoaLe1812 committed on
Commit
13f8f13
·
verified ·
1 Parent(s): ec1346d
Files changed (1) hide show
  1. api/retrieval.py +31 -10
api/retrieval.py CHANGED
@@ -181,7 +181,8 @@ class _NvidiaReranker:
181
  """Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""
182
  def __init__(self):
183
  self.api_key = os.getenv("NVIDIA_URI")
184
- self.model = "nvidia/rerank-qa-mistral-4b"
 
185
  # NIM rerank endpoint (subject to environment); keep configurable
186
  self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
187
  self.timeout_s = 30
@@ -194,16 +195,36 @@ class _NvidiaReranker:
194
  headers = {
195
  "Authorization": f"Bearer {self.api_key}",
196
  "Content-Type": "application/json",
 
197
  }
198
- payload = {
199
- "model": self.model,
200
- "query": query,
201
- "documents": [{"text": d} for d in documents],
202
- }
 
 
 
 
 
 
 
 
 
 
 
203
  try:
204
- resp = requests.post(self.base_url, headers=headers, json=payload, timeout=self.timeout_s)
205
- resp.raise_for_status()
206
- data = resp.json()
 
 
 
 
 
 
 
 
207
  # Expecting a list with scores and indices or texts
208
  results = []
209
  entries = data.get("results") or data.get("data") or []
@@ -227,4 +248,4 @@ class _NvidiaReranker:
227
  except Exception as e:
228
  logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
229
  # On failure, return original order with neutral scores
230
- return [{"text": d, "score": 0.0} for d in documents]
 
181
  """Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""
182
  def __init__(self):
183
  self.api_key = os.getenv("NVIDIA_URI")
184
+ # Use provider doc model identifier
185
+ self.model = os.getenv("NVIDIA_RERANK_MODEL", "nv-rerank-qa-mistral-4b:1")
186
  # NIM rerank endpoint (subject to environment); keep configurable
187
  self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
188
  self.timeout_s = 30
 
195
  headers = {
196
  "Authorization": f"Bearer {self.api_key}",
197
  "Content-Type": "application/json",
198
+ "Accept": "application/json",
199
  }
200
+ # Truncate and limit candidates to avoid 4xx
201
+ docs = documents[:10]
202
+ docs = [d[:2000] for d in docs if isinstance(d, str)]
203
+ # Two payload shapes based on provider doc
204
+ payloads = [
205
+ {
206
+ "model": self.model,
207
+ "query": {"text": query},
208
+ "passages": [{"text": d} for d in docs],
209
+ },
210
+ {
211
+ "model": self.model,
212
+ "query": query,
213
+ "documents": [{"text": d} for d in docs],
214
+ },
215
+ ]
216
  try:
217
+ data = None
218
+ for p in payloads:
219
+ resp = requests.post(self.base_url, headers=headers, json=p, timeout=self.timeout_s)
220
+ if resp.status_code >= 400:
221
+ # try next shape
222
+ continue
223
+ data = resp.json()
224
+ break
225
+ if data is None:
226
+ # last attempt for diagnostics
227
+ resp.raise_for_status()
228
  # Expecting a list with scores and indices or texts
229
  results = []
230
  entries = data.get("results") or data.get("data") or []
 
248
  except Exception as e:
249
  logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
250
  # On failure, return original order with neutral scores
251
+ return [{"text": d, "score": 0.0} for d in documents]