FastAPI / models_initialization /mistral_registry.py
raghavNCI
mistral after hosted on aws
f0f712f
raw
history blame
2.4 kB
import os
import json
import boto3
from botocore.config import Config
from botocore.exceptions import BotoCoreError, ClientError
# ──────────────────────────────────────────────────────────────
# Environment variables you need (add them in your HF Space)
# ──────────────────────────────────────────────────────────────
# AWS_ACCESS_KEY_ID
# AWS_SECRET_ACCESS_KEY
# AWS_REGION → e.g. "us-east-1"
# SAGEMAKER_ENDPOINT_NAME → e.g. "mistral-endpoint"
# ──────────────────────────────────────────────────────────────
# Endpoint location: both values come from the environment, with the
# fallbacks below used when a variable is not set.
AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")
ENDPOINT = os.environ.get("SAGEMAKER_ENDPOINT_NAME", "mistral-endpoint")

# Retry and timeout settings handed to the runtime client created below.
_retry_policy = {"max_attempts": 3, "mode": "standard"}
boto_cfg = Config(retries=_retry_policy, connect_timeout=10, read_timeout=120)

# Single module-level client reused by every call in this module.
sm_client = boto3.client(
    "sagemaker-runtime",
    region_name=AWS_REGION,
    config=boto_cfg,
)
def _extract_generated_text(result):
    """Pull the ``generated_text`` string out of a decoded endpoint response.

    Accepts either a non-empty list whose first element is a dict, or a
    dict directly; in both cases the text is read from the
    ``"generated_text"`` key and stripped of surrounding whitespace.

    Returns:
        The stripped text, or ``None`` when the response has any other
        shape (so the caller can tell "unrecognised response" apart from
        a legitimately empty generation).
    """
    if isinstance(result, list) and result:
        candidate = result[0]
    else:
        candidate = result

    if not isinstance(candidate, dict):
        return None

    text = candidate.get("generated_text")
    if not isinstance(text, str):
        return None
    return text.strip()


def mistral_generate(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
) -> str:
    """Call the SageMaker endpoint that hosts Mistral-7B.

    Args:
        prompt: Text sent as the request's ``"inputs"`` field.
        max_new_tokens: Forwarded as ``parameters["max_new_tokens"]``.
        temperature: Forwarded as ``parameters["temperature"]``.

    Returns:
        The generated text with surrounding whitespace stripped, or an
        empty string on failure (empty prompt, invocation error,
        undecodable body, or an unrecognised response shape). Failures
        are reported with ``print`` and never raised to the caller.
    """
    # An empty prompt carries nothing to complete; return early rather
    # than spend an endpoint call on it.
    if not prompt:
        return ""

    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
    }

    try:
        # Invoke the endpoint
        response = sm_client.invoke_endpoint(
            EndpointName=ENDPOINT,
            ContentType="application/json",
            Body=json.dumps(payload).encode("utf-8"),
        )
        # The response body is a stream: read it fully and parse it as JSON.
        result = json.loads(response["Body"].read())
    except (BotoCoreError, ClientError) as e:
        # Log SageMaker errors (throttling, auth, etc.)
        print("❌ SageMaker invocation error:", str(e))
        return ""
    except Exception as e:
        # Anything else, e.g. a body that is not valid JSON.
        print("❌ Unknown error:", str(e))
        return ""

    text = _extract_generated_text(result)
    if text is None:
        # Previously this case returned "" without any trace; say why.
        print("❌ Unexpected SageMaker response shape:", type(result).__name__)
        return ""
    return text