# flor-6.3b / sagemaker_endpoint.py
# Last commit: 6cba6d8 ("Initial commit") by PaulNdrei — file size 1.08 kB.
import os
from traceback import print_exc
import boto3
from handler import ContentHandler
from dotenv import load_dotenv
# python-dotenv: populate os.environ from a .env file (if one is found)
# before any of the settings below are read.
load_dotenv()

# Deployment settings, read in this order; each is None when unset.
endpoint_name, aws_access_key_id, aws_secret_access_key, aws_region_name = (
    os.getenv(var_name)
    for var_name in (
        "AWS_ENDPOINT_NAME",
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_REGION_NAME",
    )
)

# Shared SageMaker runtime client used by invoke_endpoint().
# NOTE(review): unset (None) credentials/region are passed to boto3 as-is.
boto_client = boto3.client(
    service_name='sagemaker-runtime',
    region_name=aws_region_name,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)

# Builds request bodies (transform_input) and decodes responses
# (transform_output) for invoke_endpoint().
content_handler = ContentHandler()
def invoke_endpoint(
    input_,
    model_parameters,
):
    """Send a prompt to the configured SageMaker endpoint and decode the reply.

    The request body is produced by ``content_handler.transform_input`` and
    sent with ``ContentType='application/json'`` via the module-level
    ``boto_client`` to the endpoint named by the module-level
    ``endpoint_name``.

    Args:
        input_: Prompt, forwarded to ``transform_input`` as ``prompt``.
        model_parameters: Model options, forwarded to ``transform_input`` as
            ``model_kwargs`` (presumably a dict whose schema is defined by
            ``ContentHandler`` — confirm there).

    Returns:
        The value of ``content_handler.transform_output`` applied to the
        response ``Body``, or ``None`` if any step raised an ``Exception``.
    """
    try:
        response = boto_client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType='application/json',
            Body=content_handler.transform_input(
                prompt=input_, model_kwargs=model_parameters
            ),
        )
        return content_handler.transform_output(response['Body'])
    except Exception:
        # Best-effort contract: report the failure (traceback to stderr) and
        # return None. Catching ``Exception`` instead of using a bare
        # ``except`` lets KeyboardInterrupt / SystemExit propagate rather
        # than being silently converted into a None result.
        print_exc()
        return None