|
import os |
|
import json |
|
import boto3 |
|
from botocore.config import Config |
|
from botocore.exceptions import BotoCoreError, ClientError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Deployment settings; each one can be overridden through the environment
# variable of the same purpose, otherwise the literal default is used.
AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")
ENDPOINT = os.environ.get("SAGEMAKER_ENDPOINT_NAME", "mistral-endpoint")
|
|
|
|
|
# Shared botocore client settings: "standard" retry mode with max_attempts=3,
# a connect timeout of 10 and a read timeout of 120 (botocore expresses both
# timeouts in seconds).
boto_cfg = Config(
    connect_timeout=10,
    read_timeout=120,
    retries=dict(max_attempts=3, mode="standard"),
)

# SageMaker Runtime client, created once when the module is imported and
# reused by every call in this file.
sm_client = boto3.client(
    "sagemaker-runtime",
    config=boto_cfg,
    region_name=AWS_REGION,
)
|
|
|
|
|
def mistral_generate(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
) -> str:
    """
    Call the SageMaker endpoint that hosts Mistral-7B.

    The prompt and generation parameters are sent as a JSON document of the
    form ``{"inputs": ..., "parameters": {...}}`` to the endpoint named by
    ``ENDPOINT`` through the module-level ``sm_client``.

    Args:
        prompt: Text passed to the model as ``inputs``.
        max_new_tokens: Forwarded as ``parameters.max_new_tokens``.
        temperature: Forwarded as ``parameters.temperature``.

    Returns:
        The generated text with surrounding whitespace stripped, or an empty
        string on failure. Failures are reported on stdout and never raised:
        this covers invocation errors, undecodable responses, and responses
        that carry no string under ``generated_text``.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
    }

    # Only the request/decode step lives in the try block, so the two handlers
    # below report transport/decoding problems and nothing else.
    try:
        response = sm_client.invoke_endpoint(
            EndpointName=ENDPOINT,
            ContentType="application/json",
            Body=json.dumps(payload).encode("utf-8"),
        )
        result = json.loads(response["Body"].read())
    except (BotoCoreError, ClientError) as e:
        print("β SageMaker invocation error:", str(e))
        return ""
    except Exception as e:
        # Best-effort contract: any other failure (e.g. invalid JSON in the
        # response body) is reported and mapped to the empty string.
        print("β Unknown error:", str(e))
        return ""

    text = _extract_generated_text(result)
    if text is None:
        # Previously this case returned "" with no diagnostic at all; report
        # the payload type rather than the payload, which may be large.
        print("β Unexpected SageMaker response format:", type(result).__name__)
        return ""
    return text


def _extract_generated_text(result):
    """
    Pull the generated text out of a decoded endpoint response.

    Accepts either a non-empty list whose first element is a JSON object, or
    a bare JSON object; in both cases the text is read from the
    ``generated_text`` key.

    Returns:
        The stripped text, or ``None`` when the payload has no string value
        under ``generated_text`` (wrong shape, missing key, null value).
    """
    if isinstance(result, list) and result:
        result = result[0]
    if not isinstance(result, dict):
        return None
    text = result.get("generated_text")
    if not isinstance(text, str):
        return None
    return text.strip()
|
|