# pylint: disable=W0707
# pylint: disable=W0719

import os
import json
import tiktoken
import openai
from openai import OpenAI
import requests

from constants.cli import OPENAI_MODELS
from constants.ai import SYSTEM_PROMPT, PROMPT, API_URL


def retrieve(query, k=10, filters=None):
    """Retrieves and returns dict.

    Args:
        query (str): User query to pass in
        openai_api_key (str): openai api key. If not passed in, uses environment variable
        k (int, optional): number of results passed back. Defaults to 10.
        filters (dict, optional): Filters to apply to the query. You can filter based off
            any piece of metadata by passing in a dict of the format {metadata_name: filter_value}
            ie {"library_id": "1234"}.

            See the README for more details:
            https://github.com/fleet-ai/context/tree/main#using-fleet-contexts-rich-metadata

    Returns:
        list: List of queried results
    """

    url = f"{API_URL}/query"
    params = {
        "query": query,
        "dataset": "python_libraries",
        "n_results": k,
        "filters": filters,
    }
    return requests.post(url, json=params, timeout=120).json()


def retrieve_context(query, openai_api_key, k=10, filters=None):
    """Gets the context from our libraries vector db for a given query.

    Args:
        query (str): User input query
        k (int, optional): number of retrieved results. Defaults to 10.
    """

    # First, we query the API
    responses = retrieve(query, k=k, filters=filters)

    # Then, we build the prompt_with_context string
    prompt_with_context = ""
    for response in responses:
        prompt_with_context += f"\n\n### Context {response['metadata']['url']} ###\n{response['metadata']['text']}"
    return {"role": "user", "content": prompt_with_context}


def construct_prompt(
    messages,
    context_message,
    model="gpt-4-1106-preview",
    cite_sources=True,
    context_window=3000,
):
    """
    Constructs a RAG (Retrieval-Augmented Generation) prompt by balancing the token count of messages and context_message.
    If the total token count exceeds the maximum limit, it adjusts the token count of each to maintain a 1:1 proportion.
    It then combines both lists and returns the result.

    Parameters:
    messages (List[dict]): List of messages to be included in the prompt.
    context_message (dict): Context message to be included in the prompt.
    model (str): The model to be used for encoding, default is "gpt-4-1106-preview".
    cite_sources (bool): Whether to insert the citation instruction prompt, default is True.
    context_window (int): Token budget for the combined prompt, default is 3000.

    Returns:
    List[dict]: The constructed RAG prompt.
    """
    # Get the encoding; default to cl100k_base
    if model in OPENAI_MODELS:
        encoding = tiktoken.encoding_for_model(model)
    else:
        encoding = tiktoken.get_encoding("cl100k_base")

    # 1) calculate tokens
    reserved_space = 1000
    max_messages_count = int((context_window - reserved_space) / 2)
    max_context_count = int((context_window - reserved_space) / 2)

    # 2) construct prompt
    prompts = messages.copy()
    prompts.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    if cite_sources:
        prompts.insert(-1, {"role": "user", "content": PROMPT})

    # 3) find how many tokens each list has
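    # Counts approximate OpenAI's ChatML framing (<|im_start|>role\ncontent<|im_end|>),
    # so the budget roughly matches what the chat API will actually see.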
    messages_token_count = len(
        encoding.encode(
            "\n".join(
                [
                    f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>"
                    for message in prompts
                ]
            )
        )
    )
    context_token_count = len(
        encoding.encode(
            f"<|im_start|>{context_message['role']}\n{context_message['content']}<|im_end|>"
        )
    )

    # 4) Balance the token count for each
    if (messages_token_count + context_token_count) > (context_window - reserved_space):
        # context has more than limit, messages has less than limit
        if (messages_token_count < max_messages_count) and (
            context_token_count > max_context_count
        ):
            max_context_count += max_messages_count - messages_token_count
        # messages has more than limit, context has less than limit
        elif (messages_token_count > max_messages_count) and (
            context_token_count < max_context_count
        ):
            max_messages_count += max_context_count - context_token_count

    # 5) Cut each list to the max count

    # Cut down messages
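    # Drop the oldest non-system message (index 1) until the running count fits;
    # if dropping a message would leave spare room, keep a truncated tail of that
    # message, decoded from the remaining token budget, instead of discarding it all.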
    while messages_token_count > max_messages_count:
        removed_encoding = encoding.encode(
            f"<|im_start|>{prompts[1]['role']}\n{prompts[1]['content']}<|im_end|>"
        )
        messages_token_count -= len(removed_encoding)
        if messages_token_count < max_messages_count:
            prompts = (
                [prompts[0]]
                + [
                    {
                        "role": prompts[1]["role"],
                        "content": encoding.decode(
                            removed_encoding[
                                : min(
                                    int(max_messages_count -
                                        messages_token_count),
                                    len(removed_encoding),
                                )
                            ]
                        )
                        .replace("<|im_start|>", "")
                        .replace("<|im_end|>", ""),
                    }
                ]
                + prompts[2:]
            )
        else:
            prompts = [prompts[0]] + prompts[2:]

    # Cut down context
    if context_token_count > max_context_count:
        # Taking a proportion of the content chars length
        reduced_chars_length = int(
            len(context_message["content"]) *
            (max_context_count / context_token_count)
        )
        context_message["content"] = context_message["content"][:reduced_chars_length]

    # 6) Combine both lists
    prompts.insert(-1, context_message)

    return prompts


def get_remote_chat_response(messages, model="gpt-4-1106-preview"):
    """
    Returns a streamed OpenAI chat response.

    Parameters:
    messages (List[dict]): List of messages to be included in the prompt.
    model (str): The model to be used for encoding, default is "gpt-4-1106-preview".

    Yields:
    str: Chunks of the streamed OpenAI chat response.
    """
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    try:
        response = client.chat.completions.create(
            model=model, messages=messages, temperature=0.2, stream=True
        )

        for chunk in response:
            current_context = chunk.choices[0].delta.content
            # The final streamed chunk carries a None delta; skip it instead of yielding None.
            if current_context is not None:
                yield current_context

    except openai.AuthenticationError as error:
        print("401 Authentication Error:", error)
        raise Exception(
            "Invalid OPENAI_API_KEY. Please re-run with a valid key.")

    except Exception as error:
        print("Streaming Error:", error)
        raise Exception("Internal Server Error")


def get_other_chat_response(messages, model="local-model"):
    """
    Returns a streamed chat response from a local server.

    Parameters:
    messages (List[dict]): List of messages to be included in the prompt.
    model (str): The model to route the request to. "local-model" targets a local
        server at http://localhost:1234; any other value is sent to OpenRouter.
        Defaults to "local-model".

    Yields:
    str: Chunks of the streamed chat response.
    """
    try:
        if model == "local-model":
            url = "http://localhost:1234/v1/chat/completions"
            headers = {"Content-Type": "application/json"}
            data = {
                "messages": messages,
                "temperature": 0.2,
                "max_tokens": -1,
                "stream": True,
            }
            response = requests.post(
                url, headers=headers, data=json.dumps(data), stream=True, timeout=120
            )

            if response.status_code == 200:
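                # The local server streams Server-Sent Events: each event looks like
                # "data: {json}". Non-JSON payloads (e.g. a terminating "data: [DONE]")
                # hit the JSONDecodeError handler below and are skipped.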
                for chunk in response.iter_content(chunk_size=None):
                    decoded_chunk = chunk.decode()
                    if (
                        "data:" in decoded_chunk
                        and decoded_chunk.split("data:")[1].strip()
                    ):  # Check if the chunk is not empty
                        try:
                            chunk_dict = json.loads(
                                decoded_chunk.split("data:")[1].strip()
                            )
                            yield chunk_dict["choices"][0]["delta"].get("content", "")
                        except json.JSONDecodeError:
                            pass
            else:
                print(f"Error: {response.status_code}, {response.text}")
                raise Exception("Internal Server Error")
        else:
            if not os.environ.get("OPENROUTER_API_KEY"):
                raise Exception(
                    f"For non-OpenAI models, like {model}, set your OPENROUTER_API_KEY."
                )

            response = requests.post(
                url="https://openrouter.ai/api/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {os.environ.get('OPENROUTER_API_KEY')}",
                    "HTTP-Referer": os.environ.get(
                        "OPENROUTER_APP_URL", "https://fleet.so/context"
                    ),
                    "X-Title": os.environ.get("OPENROUTER_APP_TITLE", "Fleet Context"),
                    "Content-Type": "application/json",
                },
                data=json.dumps(
                    {"model": model, "messages": messages, "stream": True}),
                stream=True,
                timeout=120,
            )
            if response.status_code == 200:
                for chunk in response.iter_lines():
                    decoded_chunk = chunk.decode("utf-8")
                    if (
                        "data:" in decoded_chunk
                        and decoded_chunk.split("data:")[1].strip()
                    ):  # Check if the chunk is not empty
                        try:
                            chunk_dict = json.loads(
                                decoded_chunk.split("data:")[1].strip()
                            )
                            yield chunk_dict["choices"][0]["delta"].get("content", "")
                        except json.JSONDecodeError:
                            pass
            else:
                print(f"Error: {response.status_code}, {response.text}")
                raise Exception("Internal Server Error")

    except requests.exceptions.RequestException as error:
        print("Request Error:", error)
        raise Exception(
            "Invalid request. Please check your request parameters.")