File size: 1,722 Bytes
f00f379
17fbf3d
e30a3df
 
f00f379
02e2d96
 
f00f379
e30a3df
 
f00f379
e30a3df
 
 
 
 
f0f712f
 
 
 
 
 
 
17fbf3d
e30a3df
 
17fbf3d
f00f379
 
 
 
f0f712f
 
f00f379
 
 
e30a3df
 
 
 
 
17fbf3d
e30a3df
 
 
 
 
 
 
 
 
 
 
 
 
 
17fbf3d
f0f712f
 
f00f379
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import json
import requests
from requests.exceptions import RequestException

# Endpoint settings are read from the environment once, at import time.
# The URL is passed straight to requests.post(), so it must be the full
# endpoint URL. NOTE(review): the primary variable name
# SAGEMAKER_ENDPOINT_NAME looks like a leftover from another deployment
# target — confirm; it is kept (and takes precedence) so existing setups
# resolve exactly as before. HF_ENDPOINT_URL / HF_ENDPOINT_TOKEN are accepted
# only as fallbacks because earlier error messages told users to set them.
HF_ENDPOINT_URL   = os.getenv("SAGEMAKER_ENDPOINT_NAME") or os.getenv("HF_ENDPOINT_URL")
HF_ENDPOINT_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_ENDPOINT_TOKEN")

# Fail fast with an explicit raise: `assert` is removed under `python -O`,
# which would let a missing setting slip through to the first request.
if not HF_ENDPOINT_URL:
    raise RuntimeError(
        "❌ HF endpoint URL is not set (set SAGEMAKER_ENDPOINT_NAME or HF_ENDPOINT_URL)"
    )
if not HF_ENDPOINT_TOKEN:
    raise RuntimeError(
        "❌ HF endpoint token is not set (set HF_TOKEN or HF_ENDPOINT_TOKEN)"
    )

# Sent with every request: bearer-token auth plus JSON in both directions.
HEADERS = {
    "Authorization": f"Bearer {HF_ENDPOINT_TOKEN}",
    "Content-Type":  "application/json",
    "Accept":        "application/json",
}


def mistral_generate(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    *,
    timeout: float = 90,
) -> str:
    """
    Call the Hugging Face Inference Endpoint that hosts Mistral-7B.

    POSTs ``{"inputs": prompt, "parameters": {...}}`` as JSON to
    ``HF_ENDPOINT_URL`` with the module-level ``HEADERS`` and extracts the
    ``generated_text`` field from the reply.

    Args:
        prompt: Text sent to the endpoint as ``inputs``.
        max_new_tokens: Forwarded as ``parameters.max_new_tokens``.
        temperature: Forwarded as ``parameters.temperature``.
        timeout: Seconds passed to ``requests.post`` as its timeout.
            Keyword-only; defaults to 90.

    Returns:
        The generated text with surrounding whitespace stripped, or an
        empty string on failure. Failures are reported with ``print``
        rather than raised. They cover network/HTTP errors, errors while
        decoding the body, and a reply with no string ``generated_text``
        in the expected place.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature":    temperature,
        },
    }

    # Only the HTTP round-trip and JSON decoding are expected to raise;
    # the response-shape checks below are handled explicitly instead.
    try:
        r = requests.post(
            HF_ENDPOINT_URL,
            headers=HEADERS,
            json=payload,
            timeout=timeout,   # generous default: HF spins up cold endpoints too
        )
        r.raise_for_status()
        data = r.json()
    except RequestException as e:
        print("❌ HF Endpoint error:", str(e))
        if e.response is not None:
            print("Endpoint said:", e.response.text[:300])
        return ""
    except Exception as e:
        # Best-effort contract: any other failure is reported, not raised.
        print("❌ Unknown error:", str(e))
        return ""

    # HF Endpoints usually return a *list* of dicts; some return a single dict.
    first = data[0] if isinstance(data, list) and data else data
    text = first.get("generated_text") if isinstance(first, dict) else None
    if not isinstance(text, str):
        # Surface an unusable payload instead of silently returning "".
        print("❌ Unexpected HF Endpoint response:", repr(data)[:300])
        return ""
    return text.strip()