import json
import time
import types
from typing import Callable, Optional

import httpx
import requests

import litellm
from litellm.utils import Choices, Message, ModelResponse, Usage


class AlephAlphaError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.aleph-alpha.com/complete"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class AlephAlphaConfig:
    """
    Reference: https://docs.aleph-alpha.com/api/complete/

    The `AlephAlphaConfig` class represents the configuration for the Aleph Alpha API. Here are the properties:

    - `maximum_tokens` (integer, required): The maximum number of tokens to be generated by the completion. The sum of input tokens and maximum tokens may not exceed 2048.

    - `minimum_tokens` (integer, optional; default value: 0): Generate at least this number of tokens before an end-of-text token is generated.

    - `echo` (boolean, optional; default value: false): Whether to echo the prompt in the completion.

    - `temperature` (number, nullable; default value: 0): Adjusts how creatively the model generates outputs. Use combinations of temperature, top_k, and top_p sensibly.

    - `top_k` (integer, nullable; default value: 0): Introduces randomness into token generation by considering the top k most likely options.

    - `top_p` (number, nullable; default value: 0): Adds randomness by considering the smallest set of tokens whose cumulative probability exceeds top_p.

    - `presence_penalty`, `frequency_penalty`, `sequence_penalty` (number, nullable; default value: 0): Various penalties that can reduce repetition.

    - `sequence_penalty_min_length` (integer; default value: 2): Minimum number of tokens to be considered as a sequence.

    - `repetition_penalties_include_prompt`, `repetition_penalties_include_completion`, `use_multiplicative_presence_penalty`,`use_multiplicative_frequency_penalty`,`use_multiplicative_sequence_penalty` (boolean, nullable; default value: false): Various settings that adjust how the repetition penalties are applied.

    - `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties.

    - `penalty_exceptions` (string[], nullable): Strings that may be generated without penalty.

    - `penalty_exceptions_include_stop_sequences` (boolean, nullable; default value: true): Include all stop_sequences in penalty_exceptions.

    - `best_of` (integer, nullable; default value: 1): The number of completions to be generated on the server side; only the best completion is returned.

    - `n` (integer, nullable; default value: 1): The number of completions to return.

    - `logit_bias` (object, nullable): Adjust the logit scores before sampling.

    - `log_probs` (integer, nullable): Number of top log probabilities for each token generated.

    - `stop_sequences` (string[], nullable): List of strings that will stop generation if they're generated.

    - `tokens` (boolean, nullable; default value: false): Flag indicating whether individual tokens of the completion should be returned or not.

    - `raw_completion` (boolean; default value: false): If true, the raw completion of the model will be returned.

    - `disable_optimizations` (boolean, nullable; default value: false): Disables any applied optimizations to both your prompt and completion.

    - `completion_bias_inclusion`, `completion_bias_exclusion` (string[], default value: []): Set of strings to bias the generation of tokens.

    - `completion_bias_inclusion_first_token_only`, `completion_bias_exclusion_first_token_only` (boolean; default value: false): Consider only the first token for the completion_bias_inclusion/exclusion.

    - `contextual_control_threshold` (number, nullable): If set, attention control parameters are also applied to tokens that are similar to the explicitly controlled tokens.

    - `control_log_additive` (boolean; default value: true): Whether attention control is applied to attention scores additively in log space (true) or multiplicatively (false).

    A brief usage sketch follows this class definition.
    """

    maximum_tokens: Optional[
        int
    ] = litellm.max_tokens  # aleph alpha requires max tokens
    minimum_tokens: Optional[int] = None
    echo: Optional[bool] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    sequence_penalty: Optional[float] = None
    sequence_penalty_min_length: Optional[int] = None
    repetition_penalties_include_prompt: Optional[bool] = None
    repetition_penalties_include_completion: Optional[bool] = None
    use_multiplicative_presence_penalty: Optional[bool] = None
    use_multiplicative_frequency_penalty: Optional[bool] = None
    use_multiplicative_sequence_penalty: Optional[bool] = None
    penalty_bias: Optional[str] = None
    penalty_exceptions_include_stop_sequences: Optional[bool] = None
    best_of: Optional[int] = None
    n: Optional[int] = None
    logit_bias: Optional[dict] = None
    log_probs: Optional[int] = None
    stop_sequences: Optional[list] = None
    tokens: Optional[bool] = None
    raw_completion: Optional[bool] = None
    disable_optimizations: Optional[bool] = None
    completion_bias_inclusion: Optional[list] = None
    completion_bias_exclusion: Optional[list] = None
    completion_bias_inclusion_first_token_only: Optional[bool] = None
    completion_bias_exclusion_first_token_only: Optional[bool] = None
    contextual_control_threshold: Optional[float] = None
    control_log_additive: Optional[bool] = None

    def __init__(
        self,
        maximum_tokens: Optional[int] = None,
        minimum_tokens: Optional[int] = None,
        echo: Optional[bool] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        sequence_penalty: Optional[float] = None,
        sequence_penalty_min_length: Optional[int] = None,
        repetition_penalties_include_prompt: Optional[bool] = None,
        repetition_penalties_include_completion: Optional[bool] = None,
        use_multiplicative_presence_penalty: Optional[bool] = None,
        use_multiplicative_frequency_penalty: Optional[bool] = None,
        use_multiplicative_sequence_penalty: Optional[bool] = None,
        penalty_bias: Optional[str] = None,
        penalty_exceptions_include_stop_sequences: Optional[bool] = None,
        best_of: Optional[int] = None,
        n: Optional[int] = None,
        logit_bias: Optional[dict] = None,
        log_probs: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        tokens: Optional[bool] = None,
        raw_completion: Optional[bool] = None,
        disable_optimizations: Optional[bool] = None,
        completion_bias_inclusion: Optional[list] = None,
        completion_bias_exclusion: Optional[list] = None,
        completion_bias_inclusion_first_token_only: Optional[bool] = None,
        completion_bias_exclusion_first_token_only: Optional[bool] = None,
        contextual_control_threshold: Optional[float] = None,
        control_log_additive: Optional[bool] = None,
    ) -> None:
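        # Every keyword argument that is not None is written onto the class itself,
        # so the supplied value becomes a process-wide default that get_config() returns.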
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

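# A minimal usage sketch (illustrative only; the parameter values are arbitrary):
# set class-level defaults through AlephAlphaConfig, then read the merged config back.
#
#   import litellm
#
#   litellm.AlephAlphaConfig(maximum_tokens=256, top_k=10)
#   litellm.AlephAlphaConfig.get_config()
#   # -> {"maximum_tokens": 256, "top_k": 10}
#
# completion() below copies each of these keys into optional_params unless the
# caller already supplied the same key explicitly.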

def validate_environment(api_key):
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
    }
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers

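# For reference, a sketch of the headers validate_environment() produces
# (the API key below is a placeholder):
#
#   validate_environment("MY_ALEPH_ALPHA_KEY")
#   # -> {"accept": "application/json",
#   #     "content-type": "application/json",
#   #     "Authorization": "Bearer MY_ALEPH_ALPHA_KEY"}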

def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
    default_max_tokens_to_sample=None,
):
    headers = validate_environment(api_key)
    optional_params = optional_params or {}  # guard against a None default

    ## Load Config
    config = litellm.AlephAlphaConfig.get_config()
    for k, v in config.items():
        if (
            k not in optional_params
        ):  # a value passed directly to completion() (e.g. top_k=3) takes precedence over the AlephAlphaConfig default
            optional_params[k] = v

    completion_url = api_base
    prompt = ""
    if "control" in model:  # follow the ###Instruction / ###Response format
        for idx, message in enumerate(messages):
            if "role" in message:
                if (
                    idx == 0
                ):  # set first message as instruction (required), let later user messages be input
                    prompt += f"###Instruction: {message['content']}"
                else:
                    if message["role"] == "system":
                        prompt += f"###Instruction: {message['content']}"
                    elif message["role"] == "user":
                        prompt += f"###Input: {message['content']}"
                    else:
                        prompt += f"###Response: {message['content']}"
            else:
                prompt += f"{message['content']}"
    else:
        prompt = " ".join(message["content"] for message in messages)
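    # Illustration (not executed): for a "control" model, messages such as
    #   [{"role": "system", "content": "Be brief."},
    #    {"role": "user", "content": "Hi"}]
    # are flattened by the loop above into
    #   "###Instruction: Be brief.###Input: Hi"
    # whereas other models simply join all message contents with spaces.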
    data = {
        "model": model,
        "prompt": prompt,
        **optional_params,
    }

    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
    )
    ## COMPLETION CALL
    response = requests.post(
        completion_url,
        headers=headers,
        data=json.dumps(data),
        stream=optional_params.get("stream", False),
    )
    if optional_params.get("stream") is True:
        return response.iter_lines()
    else:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")
        ## RESPONSE OBJECT
        completion_response = response.json()
        if "error" in completion_response:
            raise AlephAlphaError(
                message=completion_response["error"],
                status_code=response.status_code,
            )
        else:
            try:
                choices_list = []
                for idx, item in enumerate(completion_response["completions"]):
                    if len(item["completion"]) > 0:
                        message_obj = Message(content=item["completion"])
                    else:
                        message_obj = Message(content=None)
                    choice_obj = Choices(
                        finish_reason=item["finish_reason"],
                        index=idx,  # OpenAI-style choice indices are 0-based
                        message=message_obj,
                    )
                    choices_list.append(choice_obj)
                model_response["choices"] = choices_list
            except Exception:
                raise AlephAlphaError(
                    message=json.dumps(completion_response),
                    status_code=response.status_code,
                )

        ## CALCULATING USAGE
        prompt_tokens = len(encoding.encode(prompt))
        # content can be None for an empty completion; count it as zero tokens
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"]["content"] or "")
        )

        model_response["created"] = int(time.time())
        model_response["model"] = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        model_response.usage = usage
        return model_response

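# For reference, the parsing above expects an Aleph Alpha completion payload of
# roughly this shape (abridged; the values are illustrative):
#
#   {"completions": [{"completion": " Hello!", "finish_reason": "maximum_tokens"}]}
#
# Each entry in "completions" becomes one OpenAI-style Choices object on the
# returned ModelResponse.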

def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass