File size: 3,250 Bytes
b585c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import typing
import json
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from pydantic.v1 import root_validator

from src.utils import FakeTokenizer


class ChatContentHandler(LLMContentHandler):
    """Content handler that sends the prompt as a chat message list.

    The request wraps the prompt in a single conversation (an optional system
    message followed by the user message). The reply content is read from the
    JSON response.
    """

    content_type = "application/json"
    accepts = "application/json"
    # System message prepended to every conversation. Override on a subclass
    # or an instance to customize it; set to "" or None to send only the user
    # message.
    system_prompt: typing.Optional[str] = "You are a helpful assistant."

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Serialize *prompt* and *model_kwargs* into a UTF-8 JSON request body.

        The body has the shape
        ``{"inputs": [[<system msg>?, {"role": "user", "content": prompt}]],
        "parameters": model_kwargs}``.
        The system message is included only when ``system_prompt`` is truthy.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({'role': 'user', 'content': prompt})
        # "inputs" is a list of conversations; this handler always sends one.
        input_dict = {'inputs': [messages], "parameters": model_kwargs}
        return json.dumps(input_dict).encode("utf-8")

    def transform_output(self, output: typing.Any) -> str:
        """Extract the generated message content from the endpoint response.

        Args:
            output: readable object whose ``read()`` returns UTF-8 encoded JSON
                bytes. Presumably this is the endpoint's streaming response
                body; confirm against the caller.

        Returns:
            ``response[0]["generation"]["content"]`` from the decoded JSON.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generation"]['content']


class BaseContentHandler(LLMContentHandler):
    """Content handler that sends the raw prompt string (no chat wrapping).

    The request body is ``{"inputs": prompt, "parameters": model_kwargs}``.
    The result is the ``generation`` field of the first element of the JSON
    response.
    """

    accepts = "application/json"
    content_type = "application/json"

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Encode the prompt and generation parameters as UTF-8 JSON bytes."""
        payload = json.dumps({'inputs': prompt, "parameters": model_kwargs})
        return payload.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Read the response stream and return the first result's generation."""
        raw_body = output.read()
        first_result = json.loads(raw_body.decode("utf-8"))[0]
        return first_result["generation"]


class H2OSagemakerEndpoint(SagemakerEndpoint):
    """``SagemakerEndpoint`` that accepts explicit AWS keys and a custom tokenizer.

    When ``aws_access_key_id`` / ``aws_secret_access_key`` are left empty, the
    runtime client uses the credentials of the ``credentials_profile_name``
    profile. If no profile is set, it uses boto3's default credential
    resolution.
    """

    # Explicit AWS credentials; an empty string means "not provided".
    aws_access_key_id: str = ""
    aws_secret_access_key: str = ""
    # Optional tokenizer exposing ``encode(text)`` that returns token ids; when
    # None, ``FakeTokenizer`` is used by ``get_token_ids``.
    tokenizer: typing.Any = None

    @root_validator()
    def validate_environment(cls, values: typing.Dict) -> typing.Dict:
        """Validate that AWS credentials and the boto3 package exist in the environment.

        Builds a ``sagemaker-runtime`` client and stores it in ``values["client"]``.

        Raises:
            ImportError: if boto3 is not installed.
            ValueError: if the boto3 session or client cannot be created.
        """
        try:
            import boto3
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            # botocore treats any non-None key pair as explicit credentials.
            # Passing the "" defaults would therefore override the profile or
            # default chain with empty keys, so unset values become None.
            values["client"] = session.client(
                "sagemaker-runtime",
                region_name=values['region_name'],
                aws_access_key_id=values.get('aws_access_key_id') or None,
                aws_secret_access_key=values.get('aws_secret_access_key') or None,
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    def get_token_ids(self, text: str) -> typing.List[int]:
        """Return the token ids of *text*.

        Uses ``self.tokenizer.encode`` when a tokenizer is set. Otherwise it
        falls back to ``FakeTokenizer``, whose ``encode`` result is indexed
        by ``"input_ids"``.
        """
        if self.tokenizer is not None:
            return self.tokenizer.encode(text)
        return FakeTokenizer().encode(text)['input_ids']