FD900 commited on
Commit
f960c92
·
verified ·
1 Parent(s): 2b2842b

Update mistral_hf_wrapper.py

Browse files
Files changed (1) hide show
  1. mistral_hf_wrapper.py +19 -24
mistral_hf_wrapper.py CHANGED
@@ -1,37 +1,32 @@
1
- # mistral_hf_wrapper.py
2
-
3
  import os
4
  import requests
5
 
6
  class MistralInference:
7
- def __init__(self):
8
- self.api_url = os.getenv("HF_MISTRAL_URL")
9
- self.api_token = os.getenv("HF_TOKEN")
10
- if not self.api_url or not self.api_token:
11
- raise ValueError("Missing HF_MISTRAL_URL or HF_TOKEN environment variables")
12
 
13
  def run(self, prompt: str) -> str:
14
  headers = {
15
  "Authorization": f"Bearer {self.api_token}",
16
  "Content-Type": "application/json"
17
  }
18
-
19
  payload = {
20
  "inputs": prompt,
21
- "parameters": {
22
- "max_new_tokens": 512,
23
- "temperature": 0.7,
24
- "return_full_text": False
25
- }
26
  }
27
-
28
- response = requests.post(self.api_url, headers=headers, json=payload)
29
- response.raise_for_status()
30
-
31
- output = response.json()
32
- if isinstance(output, list) and "generated_text" in output[0]:
33
- return output[0]["generated_text"]
34
- elif isinstance(output, dict) and "generated_text" in output:
35
- return output["generated_text"]
36
- else:
37
- raise ValueError("Unexpected response format from Mistral endpoint")
 
 
 
 
 
 
 
1
  import os
2
  import requests
3
 
4
  class MistralInference:
5
+ def __init__(self, api_url=None, api_token=None):
6
+ self.api_url = api_url or os.getenv("HF_MISTRAL_ENDPOINT")
7
+ self.api_token = api_token or os.getenv("HF_TOKEN")
 
 
8
 
9
  def run(self, prompt: str) -> str:
10
  headers = {
11
  "Authorization": f"Bearer {self.api_token}",
12
  "Content-Type": "application/json"
13
  }
 
14
  payload = {
15
  "inputs": prompt,
16
+ "parameters": {"max_new_tokens": 512}
 
 
 
 
17
  }
18
+ try:
19
+ response = requests.post(self.api_url, headers=headers, json=payload)
20
+ response.raise_for_status()
21
+ output = response.json()
22
+ # Check different possible keys depending on model
23
+ if isinstance(output, list) and "generated_text" in output[0]:
24
+ return output[0]["generated_text"]
25
+ elif "generated_text" in output:
26
+ return output["generated_text"]
27
+ elif "text" in output:
28
+ return output["text"]
29
+ else:
30
+ return str(output)
31
+ except Exception as e:
32
+ return f"Error querying Mistral: {str(e)}"