lukestanley committed on
Commit 8093276
1 Parent(s): 3c6c618

Add Mistral API support due to my RunPod serverless system reliability issues

Files changed (1):
  1. utils.py +62 -3
utils.py CHANGED
@@ -1,4 +1,5 @@
 import json
+from time import time, sleep
 from os import environ as env
 from typing import Any, Dict, Union
 
@@ -6,7 +7,7 @@ import requests
 from huggingface_hub import hf_hub_download
 
 
-# There are 3 ways to use the LLM model currently used:
+# There are 4 ways to use the LLM currently:
 # 1. Use the HTTP server (USE_HTTP_SERVER=True), this is good for development
 # when you want to change the logic of the translator without restarting the server.
 # 2. Load the model into memory
@@ -17,12 +18,13 @@ from huggingface_hub import hf_hub_download
 # It's possible to switch to another LLM API by changing the llm_streaming function.
 # 3. Use the RunPod API, which is a paid service with serverless GPU functions.
 # See serverless.md for more information.
+# 4. Use the Mistral API, which is a paid service.
 
 URL = "http://localhost:5834/v1/chat/completions"
 in_memory_llm = None
-worker_options = ["runpod", "http", "in_memory"]
+worker_options = ["runpod", "http", "in_memory", "mistral"]
 
-LLM_WORKER = env.get("LLM_WORKER", "runpod")
+LLM_WORKER = env.get("LLM_WORKER", "mistral")
 if LLM_WORKER not in worker_options:
     raise ValueError(f"Invalid worker: {LLM_WORKER}")
 N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1))  # Default to -1, use all layers if available
@@ -184,9 +186,63 @@ def llm_stream_serverless(prompt,model):
 
 def query_ai_prompt(prompt, replacements, model_class):
     prompt = replace_text(prompt, replacements)
+    if LLM_WORKER == "mistral":
+        return llm_stream_mistral_api(prompt, model_class)
     if LLM_WORKER == "runpod":
         return llm_stream_serverless(prompt, model_class)
     if LLM_WORKER == "http":
         return llm_streaming(prompt, model_class)
     if LLM_WORKER == "in_memory":
         return llm_stream_sans_network(prompt, model_class)
+
+
+# Global variables to enforce rate limiting
+LAST_REQUEST_TIME = None
+REQUEST_INTERVAL = 0.5  # Minimum time interval between requests in seconds
+
+
+def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict[str, Any]]:
+    global LAST_REQUEST_TIME
+    current_time = time()
+    if LAST_REQUEST_TIME is not None:
+        elapsed_time = current_time - LAST_REQUEST_TIME
+        if elapsed_time < REQUEST_INTERVAL:
+            sleep_time = REQUEST_INTERVAL - elapsed_time
+            sleep(sleep_time)
+            print(f"Slept for {sleep_time} seconds to enforce rate limit")
+    LAST_REQUEST_TIME = time()
+
+    MISTRAL_API_URL = env.get("MISTRAL_API_URL", "https://api.mistral.ai/v1/chat/completions")
+    MISTRAL_API_KEY = env.get("MISTRAL_API_KEY", None)
+    if not MISTRAL_API_KEY:
+        raise ValueError("MISTRAL_API_KEY environment variable not set")
+    headers = {
+        'Content-Type': 'application/json',
+        'Accept': 'application/json',
+        'Authorization': f'Bearer {MISTRAL_API_KEY}'
+    }
+    data = {
+        'model': 'mistral-small-latest',
+        # response_format is a top-level field of the payload, not part of a message:
+        'response_format': {'type': 'json_object'},
+        'messages': [
+            {
+                'role': 'user',
+                'content': prompt
+            }
+        ]
+    }
+    response = requests.post(MISTRAL_API_URL, headers=headers, json=data)
+    if response.status_code != 200:
+        raise ValueError(f"Unexpected Mistral API status code: {response.status_code} with body: {response.text}")
+    result = response.json()
+    print(result)
+    output = result['choices'][0]['message']['content']
+    if pydantic_model_class:
+        # model_validate_json raises if the output does not match the model.
+        # TODO: handle the exception with retry logic
+        parsed_result = pydantic_model_class.model_validate_json(output)
+        print(parsed_result)
+    else:
+        print("No pydantic model class provided, returning without class validation")
+    return json.loads(output)
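utils.py reads LLM_WORKER once, at import time, so the backend has to be chosen before the module is imported; with this commit the default worker is "mistral" and MISTRAL_API_KEY must be set for it. A minimal usage sketch, assuming the module imports as utils and that replace_text does literal substitution of each replacements key (the prompt text here is made up):

    from os import environ

    # Both variables are read by utils.py at import time.
    environ["LLM_WORKER"] = "mistral"   # one of: runpod, http, in_memory, mistral
    environ["MISTRAL_API_KEY"] = "..."  # your real key

    import utils

    # Without a pydantic class, the Mistral worker still returns json.loads(output),
    # i.e. a plain dict parsed from the model's JSON-mode reply.
    result = utils.query_ai_prompt(
        'Return a JSON object with a "summary" key for: {text}',
        {"{text}": "a long day of debugging"},
        None,
    )
    print(result)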
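The rate limiting is a module-global throttle: each call measures the elapsed time since the previous request and sleeps for whatever remains of REQUEST_INTERVAL, so calls within one process are at least 0.5 seconds apart (it has no effect across processes). The same pattern in isolation, a self-contained sketch with hypothetical names rather than the committed code:

    from time import time, sleep

    _last_call = None
    _interval = 0.5  # minimum seconds between calls

    def throttle():
        # Sleep just long enough to keep consecutive calls _interval apart.
        global _last_call
        now = time()
        if _last_call is not None and now - _last_call < _interval:
            sleep(_interval - (now - _last_call))
        _last_call = time()

    for i in range(3):
        throttle()  # the second and third iterations sleep roughly 0.5 s
        print(i, time())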
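When a pydantic class is passed, llm_stream_mistral_api validates the reply with model_validate_json (the pydantic v2 API) but still returns the dict from json.loads; validation errors currently propagate uncaught, per the TODO. A hedged calling sketch; the Summary class is hypothetical and not part of the commit:

    from pydantic import BaseModel

    class Summary(BaseModel):  # hypothetical schema matching what the prompt asks for
        title: str
        sentiment: str

    import utils  # assumes LLM_WORKER=mistral and MISTRAL_API_KEY are set

    prompt = 'Reply with a JSON object {"title": ..., "sentiment": ...} for: {text}'
    result = utils.query_ai_prompt(prompt, {"{text}": "What a lovely surprise!"}, Summary)
    print(result["title"], result["sentiment"])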