FD900 committed on
Commit
a57fcec
·
verified ·
1 Parent(s): d6ccc6c

Update mistral_hf_wrapper.py

Browse files
Files changed (1) hide show
  1. mistral_hf_wrapper.py +17 -28
mistral_hf_wrapper.py CHANGED
@@ -1,32 +1,21 @@
1
- import requests
2
  import os
 
3
 
4
- class MistralInference:
5
- def __init__(self, api_url: str, api_token: str):
6
- self.api_url = api_url.rstrip("/")
7
- self.headers = {
8
- "Authorization": f"Bearer {api_token}",
9
- "Content-Type": "application/json"
10
- }
11
-
12
- def generate(self, prompt: str, temperature: float = 0.7, max_tokens: int = 512) -> str:
13
- payload = {
14
- "inputs": prompt,
15
- "parameters": {
16
- "temperature": temperature,
17
- "max_new_tokens": max_tokens,
18
- "return_full_text": False
19
- }
20
- }
21
-
22
- response = requests.post(
23
- f"{self.api_url}/generate",
24
- headers=self.headers,
25
- json=payload
26
- )
27
 
28
- if response.status_code != 200:
29
- raise RuntimeError(f"Request failed: {response.status_code} - {response.text}")
 
 
30
 
31
- data = response.json()
32
- return data["generated_text"] if "generated_text" in data else data[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import requests
3
 
4
+ API_URL = os.getenv("HF_MISTRAL_ENDPOINT")
5
+ API_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ headers = {
8
+ "Authorization": f"Bearer {API_TOKEN}",
9
+ "Content-Type": "application/json"
10
+ }
11
 
12
+ def query_mistral(system_prompt: str, user_prompt: str) -> str:
13
+ """Query the Mistral model hosted on Hugging Face."""
14
+ prompt = f"<s>[INST] {system_prompt.strip()}\n\n{user_prompt.strip()} [/INST]"
15
+ response = requests.post(
16
+ API_URL,
17
+ headers=headers,
18
+ json={"inputs": prompt}
19
+ )
20
+ response.raise_for_status()
21
+ return response.json()["generated_text"].strip()