|
import os |
|
import json |
|
import requests |
|
from requests.exceptions import RequestException |
|
|
|
# Endpoint location and bearer token are read from the environment once, at
# import time.
# NOTE(review): the endpoint URL is read from SAGEMAKER_ENDPOINT_NAME even
# though the constant is named HF_ENDPOINT_URL — confirm that this is the
# intended environment variable.
HF_ENDPOINT_URL = os.getenv("SAGEMAKER_ENDPOINT_NAME")
HF_ENDPOINT_TOKEN = os.getenv("HF_TOKEN")

# Fail fast when configuration is missing. Explicit raises are used instead of
# `assert` so the checks are not stripped under `python -O`; AssertionError is
# kept so the exception type seen by importers does not change. The messages
# name the environment variables that are actually read, so the operator knows
# which one to set.
if not HF_ENDPOINT_URL:
    raise AssertionError(
        "β SAGEMAKER_ENDPOINT_NAME is not set (required for HF_ENDPOINT_URL)"
    )
if not HF_ENDPOINT_TOKEN:
    raise AssertionError(
        "β HF_TOKEN is not set (required for HF_ENDPOINT_TOKEN)"
    )

# Headers sent with every request: bearer authentication plus JSON in and out.
HEADERS = {
    "Authorization": f"Bearer {HF_ENDPOINT_TOKEN}",
    "Content-Type": "application/json",
    "Accept": "application/json",
}
|
|
|
|
|
def mistral_generate(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    timeout: float = 90,
) -> str:
    """Call the Hugging Face Inference Endpoint that hosts Mistral-7B.

    POSTs ``prompt`` and the generation parameters as JSON to
    ``HF_ENDPOINT_URL`` using ``HEADERS``, then extracts ``generated_text``
    from the JSON reply. Two reply shapes are accepted: a non-empty list
    whose first element carries ``generated_text``, or a dict that has a
    ``generated_text`` key.

    Args:
        prompt: Text sent as the ``inputs`` field of the request.
        max_new_tokens: Forwarded as ``parameters.max_new_tokens``.
        temperature: Forwarded as ``parameters.temperature``.
        timeout: Timeout in seconds passed to ``requests.post``. Defaults to
            90, the value that used to be hard-coded.

    Returns:
        The generated text with surrounding whitespace stripped, or an empty
        string on failure. Failures (transport/HTTP errors, a reply with an
        unexpected shape, or any other exception) are printed to stdout and
        never raised to the caller.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
    }

    try:
        response = requests.post(
            HF_ENDPOINT_URL,
            headers=HEADERS,
            json=payload,
            timeout=timeout,
        )
        # Turn 4xx/5xx replies into an exception handled below.
        response.raise_for_status()
        data = response.json()

        # Shape 1: a non-empty list; the text is taken from its first element.
        if isinstance(data, list) and data:
            return data[0].get("generated_text", "").strip()

        # Shape 2: a single object carrying the text directly.
        if isinstance(data, dict) and "generated_text" in data:
            return data["generated_text"].strip()

        # Anything else (empty list, dict without the key, other JSON types)
        # used to fall through silently; report it so an empty result can be
        # diagnosed. The dump is truncated like the error body below.
        print("HF Endpoint returned an unexpected payload:", str(data)[:300])

    except RequestException as e:
        print("β HF Endpoint error:", str(e))
        # The response is only attached when the server actually replied.
        if e.response is not None:
            print("Endpoint said:", e.response.text[:300])

    except Exception as e:
        # Deliberate best-effort boundary: callers rely on "" instead of an
        # exception, so everything else is reported and swallowed here.
        print("β Unknown error:", str(e))

    return ""
|
|