| """Wrapper around Sagemaker InvokeEndpoint API.""" | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Extra, root_validator | |
| from langchain.embeddings.base import Embeddings | |
| from langchain.llms.sagemaker_endpoint import ContentHandlerBase | |
class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
    """Wrapper around custom Sagemaker Inference Endpoints.

    To use, you must supply the endpoint name from your deployed
    Sagemaker model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """
| """ | |
| Example: | |
| .. code-block:: python | |
| from langchain.embeddings import SagemakerEndpointEmbeddings | |
| endpoint_name = ( | |
| "my-endpoint-name" | |
| ) | |
| region_name = ( | |
| "us-west-2" | |
| ) | |
| credentials_profile_name = ( | |
| "default" | |
| ) | |
| se = SagemakerEndpointEmbeddings( | |
| endpoint_name=endpoint_name, | |
| region_name=region_name, | |
| credentials_profile_name=credentials_profile_name | |
| ) | |
| """ | |
    client: Any  #: :meta private:

    endpoint_name: str = ""
    """The name of the endpoint from the deployed Sagemaker model.
    Must be unique within an AWS Region."""

    region_name: str = ""
    """The aws region where the Sagemaker model is deployed, e.g. `us-west-2`."""

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """
    content_handler: ContentHandlerBase
    """The content handler class that provides the input and
    output transform functions to handle formats between the LLM
    and the endpoint.
    """
| """ | |
| Example: | |
| .. code-block:: python | |
| from langchain.llms.sagemaker_endpoint import ContentHandlerBase | |
| class ContentHandler(ContentHandlerBase): | |
| content_type = "application/json" | |
| accepts = "application/json" | |
| def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: | |
| input_str = json.dumps({prompt: prompt, **model_kwargs}) | |
| return input_str.encode('utf-8') | |
| def transform_output(self, output: bytes) -> str: | |
| response_json = json.loads(output.read().decode("utf-8")) | |
| return response_json[0]["generated_text"] | |
| """ | |
    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes passed to the invoke_endpoint
    function. See `boto3`_ docs for more info.
    .. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
    """
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that AWS credentials and the boto3 package exist in the environment."""
        try:
            import boto3

            try:
                if values["credentials_profile_name"] is not None:
                    session = boto3.Session(
                        profile_name=values["credentials_profile_name"]
                    )
                else:
                    # use default credentials
                    session = boto3.Session()

                values["client"] = session.client(
                    "sagemaker-runtime", region_name=values["region_name"]
                )
            except Exception as e:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e
        except ImportError:
            raise ValueError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        return values
    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """Call out to the SageMaker Inference embedding endpoint for a batch of texts."""
        # replace newlines, which can negatively affect performance.
        texts = list(map(lambda x: x.replace("\n", " "), texts))
        _model_kwargs = self.model_kwargs or {}
        _endpoint_kwargs = self.endpoint_kwargs or {}

        body = self.content_handler.transform_input(texts, _model_kwargs)
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        # send request
        try:
            response = self.client.invoke_endpoint(
                EndpointName=self.endpoint_name,
                Body=body,
                ContentType=content_type,
                Accept=accepts,
                **_endpoint_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

        return self.content_handler.transform_output(response["Body"])
    def embed_documents(
        self, texts: List[str], chunk_size: int = 64
    ) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: Defines how many input texts are grouped together
                per request. If chunk_size exceeds the number of texts,
                all texts are sent in a single request.

        Returns:
            List of embeddings, one for each text.
        """
        results: List[List[float]] = []
        if not texts:
            return results
        _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
        for i in range(0, len(texts), _chunk_size):
            response = self._embedding_func(texts[i : i + _chunk_size])
            results.extend(response)
        return results
    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a SageMaker inference endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding_func([text])[0]
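

# A minimal end-to-end sketch, kept under a __main__ guard so importing this
# module stays side-effect free. The endpoint name, region, and the JSON
# schema handled by the demo content handler are placeholders/assumptions;
# running this for real requires a deployed SageMaker embedding endpoint,
# boto3, and valid AWS credentials.
if __name__ == "__main__":
    import json

    class _DemoContentHandler(ContentHandlerBase):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
            # Assumes the endpoint expects {"text_inputs": [...]} plus model kwargs.
            return json.dumps({"text_inputs": prompts, **model_kwargs}).encode("utf-8")

        def transform_output(self, output: bytes) -> List[List[float]]:
            # Assumes the endpoint returns {"embedding": [[...], ...]}.
            return json.loads(output.read().decode("utf-8"))["embedding"]

    embeddings = SagemakerEndpointEmbeddings(
        endpoint_name="my-embedding-endpoint",  # placeholder endpoint name
        region_name="us-west-2",  # placeholder region
        content_handler=_DemoContentHandler(),
    )
    doc_vectors = embeddings.embed_documents(["first document", "second document"])
    query_vector = embeddings.embed_query("a query string")
    print(f"{len(doc_vectors)} document embeddings, query dimension {len(query_vector)}")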