File size: 4,920 Bytes
8a41f4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import io
import json
import logging

from application.llm.base import BaseLLM
from application.core.settings import settings



class LineIterator:
    r"""
    Iterate over newline-terminated lines in a SageMaker response event stream.

    The model output arrives as ``PayloadPart`` events whose bytes look like::

        b'{"outputs": [" a"]}\n'
        b'{"outputs": [" challenging"]}\n'
        b'{"outputs": [" problem"]}\n'
        ...

    Usually each ``PayloadPart`` event carries one complete JSON line, but this
    is not guaranteed: a single line may be split across several events, e.g.::

        {'PayloadPart': {'Bytes': b'{"outputs": '}}
        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}

    This iterator appends the bytes of every ``PayloadPart`` event to an
    internal buffer and yields one complete line at a time, without its
    trailing ``b'\n'``. ``read_pos`` records how far the buffer has been
    consumed so no bytes are yielded twice. If the stream ends with bytes that
    are not newline-terminated, they are yielded once as a final line. Events
    without a ``PayloadPart`` key are logged and skipped.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first buffered byte not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            # readline() stops at b'\n' or at the end of the buffer, so a line
            # without a trailing newline is an incomplete (so far) line.
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # The stream is exhausted. Any unconsumed bytes form an
                # unterminated final line: hand them out once. Retrying the
                # exhausted iterator instead would loop forever.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                logging.getLogger(__name__).warning('Unknown event type: %s', chunk)
                continue
            # Append new bytes at the end; the next pass re-seeks to read_pos.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])

class SagemakerAPILLM(BaseLLM):
    """LLM backend that sends prompts to an Amazon SageMaker inference endpoint.

    The endpoint name is read from ``settings.SAGEMAKER_ENDPOINT``. Each request
    body is JSON with ``inputs``, ``stream`` and ``parameters`` keys.

    NOTE(review): the response handling assumes a text-generation container.
    Blocking calls are expected to return a list of ``{"generated_text": ...}``
    objects. Streaming calls are expected to emit lines containing
    ``{"token": {"text": ...}}``. Confirm this against the deployed container.
    """

    # Stop sequences sent to the endpoint. The streaming loop also drops them
    # if the endpoint emits them as tokens.
    _STOP_SEQUENCES = ("</s>", "###")

    _logger = logging.getLogger(__name__)

    def __init__(self, *args, **kwargs):
        import boto3

        # The original hard-coded placeholder credentials ('xxx'), which can
        # never authenticate. Credentials and region are now taken from
        # settings when those attributes exist. When they are absent, the keys
        # are passed as None, which makes boto3 fall back to its default
        # credential chain (env vars, shared config, instance role).
        # NOTE(review): SAGEMAKER_ACCESS_KEY / SAGEMAKER_SECRET_KEY /
        # SAGEMAKER_REGION are assumed setting names; confirm against the
        # application settings.
        self.runtime = boto3.client(
            'runtime.sagemaker',
            aws_access_key_id=getattr(settings, 'SAGEMAKER_ACCESS_KEY', None),
            aws_secret_access_key=getattr(settings, 'SAGEMAKER_SECRET_KEY', None),
            region_name=getattr(settings, 'SAGEMAKER_REGION', None) or 'us-west-2',
        )
        self.endpoint = settings.SAGEMAKER_ENDPOINT

    @staticmethod
    def _build_prompt(messages):
        """Build the instruction/context/answer prompt from chat messages.

        The first message's ``content`` is used as the context and the last
        message's ``content`` as the user question.
        """
        context = messages[0]['content']
        user_question = messages[-1]['content']
        return f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

    def _encode_payload(self, prompt, stream, max_new_tokens):
        """Return the endpoint request body as UTF-8 encoded JSON bytes."""
        payload = {
            "inputs": prompt,
            "stream": stream,
            "parameters": {
                "do_sample": True,
                "temperature": 0.1,
                "max_new_tokens": max_new_tokens,
                "repetition_penalty": 1.03,
                "stop": list(self._STOP_SEQUENCES),
            },
        }
        return json.dumps(payload).encode('utf-8')

    def gen(self, model, engine, messages, stream=False, **kwargs):
        """Return the endpoint's completion for ``messages`` as one string.

        ``model``, ``engine``, ``stream`` and extra keyword arguments are
        accepted for interface compatibility but are not used.
        """
        prompt = self._build_prompt(messages)
        # NOTE(review): 30 new tokens is far below gen_stream's 512; confirm
        # this short limit is intended.
        response = self.runtime.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType='application/json',
            Body=self._encode_payload(prompt, stream=False, max_new_tokens=30),
        )
        result = json.loads(response['Body'].read().decode())
        generated = result[0]['generated_text']
        self._logger.debug("SageMaker generated_text: %s", generated)
        # Some containers echo the prompt at the start of generated_text. Strip
        # it only when it is actually there, so a prompt-free response is not
        # truncated by len(prompt) characters.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        return generated

    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
        """Yield the endpoint's completion for ``messages`` token by token.

        Each line of the response event stream is parsed as JSON from its first
        ``{`` onward. ``data['token']['text']`` is yielded unless it is one of
        the stop sequences. Lines without a ``{`` are skipped. ``model``,
        ``engine``, ``stream`` and extra keyword arguments are not used.
        """
        prompt = self._build_prompt(messages)
        response = self.runtime.invoke_endpoint_with_response_stream(
            EndpointName=self.endpoint,
            ContentType='application/json',
            Body=self._encode_payload(prompt, stream=True, max_new_tokens=512),
        )
        start_json = b'{'
        for line in LineIterator(response['Body']):
            # Lines may carry a prefix before the JSON object (presumably
            # server-sent-event framing such as b'data:'), so parse from the
            # first '{'.
            json_start = line.find(start_json)
            if json_start == -1:
                continue
            data = json.loads(line[json_start:].decode('utf-8'))
            text = data['token']['text']
            if text not in self._STOP_SEQUENCES:
                yield text