lukestanley commited on
Commit
1e622b4
1 Parent(s): 135f3ac

Add retry logic upon schema fail for Mistral API calls

Browse files
Files changed (1) hide show
  1. utils.py +19 -5
utils.py CHANGED
@@ -190,7 +190,7 @@ def llm_stream_serverless(prompt,model):
190
  LAST_REQUEST_TIME = None
191
  REQUEST_INTERVAL = 0.5 # Minimum time interval between requests in seconds
192
 
193
- def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict[str, Any]]:
194
  global LAST_REQUEST_TIME
195
  current_time = time()
196
  if LAST_REQUEST_TIME is not None:
@@ -227,10 +227,24 @@ def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict
227
  print(result)
228
  output = result['choices'][0]['message']['content']
229
  if pydantic_model_class:
230
- parsed_result = pydantic_model_class.model_validate_json(output)
231
- print(parsed_result)
232
- # This will raise an exception if the model is invalid,
233
- # TODO: handle exception with retry logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  else:
235
  print("No pydantic model class provided, returning without class validation")
236
  return json.loads(output)
 
190
  LAST_REQUEST_TIME = None
191
  REQUEST_INTERVAL = 0.5 # Minimum time interval between requests in seconds
192
 
193
+ def llm_stream_mistral_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]:
194
  global LAST_REQUEST_TIME
195
  current_time = time()
196
  if LAST_REQUEST_TIME is not None:
 
227
  print(result)
228
  output = result['choices'][0]['message']['content']
229
  if pydantic_model_class:
230
+ # TODO: Use more robust error handling that works for all cases without retrying?
231
+ # Maybe APIs that dont have grammar should be avoided?
232
+ # Investigate grammar enforcement with open ended generations?
233
+ try:
234
+ parsed_result = pydantic_model_class.model_validate_json(output)
235
+ print(parsed_result)
236
+ # This will raise an exception if the model is invalid,
237
+ except Exception as e:
238
+ print(f"Error validating pydantic model: {e}")
239
+ # Let's retry by calling ourselves again if attempts < 3
240
+ if attempts == 0:
241
+ # We modify the prompt to remind it to output JSON in the required format
242
+ prompt = f"{prompt} You must output the JSON in the required format!"
243
+ if attempts < 3:
244
+ attempts += 1
245
+ print(f"Retrying Mistral API call, attempt {attempts}")
246
+ return llm_stream_mistral_api(prompt, pydantic_model_class, attempts)
247
+
248
  else:
249
  print("No pydantic model class provided, returning without class validation")
250
  return json.loads(output)