File size: 9,267 Bytes
327982a
8093276
74d6e52
327982a
 
469f650
74d6e52
139217d
56e785c
8093276
74d6e52
 
 
 
a96b492
 
 
6c32632
56e785c
469f650
8093276
327982a
a96b492
ddb0d91
8093276
74d6e52
8093276
56e785c
 
 
 
74d6e52
56e785c
e01e28e
e327a9e
74d6e52
469f650
 
 
74d6e52
469f650
74d6e52
 
 
 
 
 
469f650
 
74d6e52
56e785c
74d6e52
e01e28e
327982a
e4b918c
 
 
327982a
e4b918c
f84c1a6
327982a
 
e4b918c
327982a
 
 
 
 
e01e28e
327982a
e01e28e
e4b918c
327982a
 
 
 
 
e4b918c
a96b492
e4b918c
 
 
 
327982a
 
 
 
 
 
 
 
 
e4b918c
327982a
 
e4b918c
3ebb6e1
327982a
 
 
 
 
 
 
 
e4b918c
327982a
 
 
 
 
e4b918c
139217d
 
 
 
 
 
ddb0d91
 
 
 
 
 
 
 
 
 
 
 
 
 
a0f49a0
ddb0d91
e01e28e
 
ddb0d91
a0f49a0
 
 
 
 
 
 
 
ddb0d91
a0f49a0
ddb0d91
 
 
 
 
 
 
 
9475016
56e785c
469f650
 
00af17e
 
56e785c
9475016
56e785c
 
 
 
 
 
 
 
 
 
 
 
 
 
00af17e
56e785c
469f650
3c6c618
 
c013599
 
56e785c
 
8093276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124003a
 
 
 
 
 
 
 
 
 
 
 
 
8093276
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import json
from time import time, sleep
from os import environ as env
from typing import Any, Dict, Union

import requests
from huggingface_hub import hf_hub_download  


# There are currently 4 ways to use an LLM, selected with the LLM_WORKER environment variable:
# 1. Use the HTTP server (LLM_WORKER="http"). This is good for development
# when you want to change the logic of the translator without restarting the server.
# 2. Load the model into memory (LLM_WORKER="in_memory").
# When using the HTTP server, it must be run separately. See the README for instructions.
# The llama_cpp Python HTTP server communicates with the AI model, similar
# to the OpenAI API, but adds a unique "grammar" parameter.
# The real OpenAI API has other ways to set the output format.
# It's possible to switch to another LLM API by changing the llm_streaming function.
# 3. Use the RunPod API (LLM_WORKER="runpod"), which is a paid service with serverless GPU functions.
# See serverless.md for more information.
# 4. Use the Mistral API (LLM_WORKER="mistral"), which is a paid service.

# Endpoint of the separately started llama_cpp HTTP server (used by llm_streaming).
URL = "http://localhost:5834/v1/chat/completions"
# In-process Llama instance; only created below when LLM_WORKER is "in_memory".
in_memory_llm = None
worker_options = ["runpod", "http", "in_memory", "mistral"]

# Backend selection, validated at import time so a typo fails immediately.
LLM_WORKER = env.get("LLM_WORKER", "mistral")
if LLM_WORKER not in worker_options:
    raise ValueError(f"Invalid worker: {LLM_WORKER}")
N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1))  # Default to -1, use all layers if available
CONTEXT_SIZE = int(env.get("CONTEXT_SIZE", 2048))
LLM_MODEL_PATH = env.get("LLM_MODEL_PATH", None)

# Generation settings used by the "http" and "in_memory" workers.
MAX_TOKENS = int(env.get("MAX_TOKENS", 1000))
TEMPERATURE = float(env.get("TEMPERATURE", 0.3))

# True when the model weights are needed on this machine.
performing_local_inference = LLM_WORKER in ("in_memory", "http")

if LLM_MODEL_PATH:
    print(f"Using local model from {LLM_MODEL_PATH}")
if performing_local_inference and not LLM_MODEL_PATH:
    print("No local LLM_MODEL_PATH environment variable set. We need a model, downloading model from HuggingFace Hub")
    LLM_MODEL_PATH = hf_hub_download(
        repo_id=env.get("REPO_ID", "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"),
        filename=env.get("MODEL_FILE", "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"),
    )
    print(f"Model downloaded to {LLM_MODEL_PATH}")
if performing_local_inference:
    # Imported lazily so the remote workers do not require llama_cpp to be installed.
    from llama_cpp import Llama, LlamaGrammar, json_schema_to_gbnf

if in_memory_llm is None and LLM_WORKER == "in_memory":
    # The worker is chosen with LLM_WORKER; point the user at the variable that is actually read.
    print("Loading model into memory. If you didn't want this, set the LLM_WORKER environment variable to 'http'.")
    in_memory_llm = Llama(model_path=LLM_MODEL_PATH, n_ctx=CONTEXT_SIZE, n_gpu_layers=N_GPU_LAYERS, verbose=True)

def llm_streaming(
    prompt: str, pydantic_model_class, return_pydantic_object=False
) -> Union[str, Dict[str, Any]]:
    """Stream a grammar-constrained completion from the llama_cpp HTTP server.

    The JSON schema of ``pydantic_model_class`` is converted to a GBNF grammar
    and sent along with the prompt to ``URL``. Tokens are echoed to stdout as
    they arrive and concatenated into the final JSON text.

    Args:
        prompt: The user message sent to the model.
        pydantic_model_class: Class providing ``model_json_schema()`` and
            ``model_validate_json()``; its schema defines the output format.
        return_pydantic_object: When true, return the validated model object
            instead of the decoded JSON value.

    Returns:
        The validated model object if ``return_pydantic_object`` is true,
        otherwise the result of ``json.loads`` on the generated text.

    Raises:
        ValueError: If the server answers with a non-200 status code, or if
            the generated text is not valid JSON (``json.JSONDecodeError``).
    """
    schema = pydantic_model_class.model_json_schema()

    # Optional example field from schema, is not needed for the grammar generation
    schema.pop("example", None)

    grammar = json_schema_to_gbnf(json.dumps(schema))

    payload = {
        "stream": True,
        "max_tokens": MAX_TOKENS,
        "grammar": grammar,
        "temperature": TEMPERATURE,
        "messages": [{"role": "user", "content": prompt}],
    }
    headers = {
        "Content-Type": "application/json",
    }

    event_prefix = "data: "
    tokens = []
    # The context manager releases the streamed connection even if parsing fails.
    with requests.post(URL, headers=headers, json=payload, stream=True) as response:
        if response.status_code != 200:
            raise ValueError(
                f"Unexpected LLM server status code: {response.status_code} with body: {response.text}"
            )
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if not line.startswith(event_prefix):
                continue
            # Strip only the leading prefix: the payload itself may contain
            # the text "data: ", so splitting on it would truncate the JSON.
            data = line[len(event_prefix):]
            if data.strip() == "[DONE]":
                break
            event = json.loads(data)
            new_token = event["choices"][0].get("delta", {}).get("content")
            if new_token:
                tokens.append(new_token)
                print(new_token, sep="", end="", flush=True)
    print('\n')

    output_text = "".join(tokens)
    if return_pydantic_object:
        return pydantic_model_class.model_validate_json(output_text)
    return json.loads(output_text)


def replace_text(template: str, replacements: dict) -> str:
    """Fill ``{key}`` placeholders in *template* with values from *replacements*.

    Each key is substituted in dictionary order, replacing every occurrence of
    ``{key}``. Values are converted with ``str()`` so non-string values such
    as numbers can be used directly. Placeholders without a matching key are
    left untouched.

    Args:
        template: Text containing ``{key}`` placeholders.
        replacements: Mapping of placeholder name to replacement value.

    Returns:
        The template with all known placeholders substituted.
    """
    for key, value in replacements.items():
        template = template.replace(f"{{{key}}}", str(value))
    return template




def calculate_overall_score(faithfulness, spiciness):
    """Return faithfulness plus a spiciness bonus that is scaled by faithfulness.

    The bonus is ``(1 - 0.8) * spiciness * faithfulness``, so a faithfulness of
    zero always yields an overall score of zero.
    """
    baseline_weight = 0.8
    spiciness_weight = 1 - baseline_weight
    spiciness_bonus = spiciness_weight * spiciness * faithfulness
    return faithfulness + spiciness_bonus


def llm_stream_sans_network(
    prompt: str, pydantic_model_class, return_pydantic_object=False
) -> Union[str, Dict[str, Any]]:
    """Run the prompt through the in-memory model, constrained by the model's JSON schema.

    Generated text is echoed to stdout while streaming. The collected text is
    returned either as a validated ``pydantic_model_class`` object (when
    ``return_pydantic_object`` is true) or as the value decoded by ``json.loads``.
    """
    model_schema = pydantic_model_class.model_json_schema()

    # The grammar does not need the schema's optional "example" entry.
    model_schema.pop("example", None)

    schema_text = json.dumps(model_schema)
    output_grammar = LlamaGrammar.from_json_schema(schema_text)

    token_stream = in_memory_llm(
        prompt,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        grammar=output_grammar,
        stream=True,
    )

    pieces = []
    for event in token_stream:
        piece = event["choices"][0]["text"]
        print(piece, end='', flush=True)
        pieces.append(piece)
    print('\n')

    generated_text = "".join(pieces)
    if not return_pydantic_object:
        return json.loads(generated_text)
    return pydantic_model_class.model_validate_json(generated_text)


def llm_stream_serverless(prompt, model):
    """Run the prompt on a RunPod serverless endpoint and return the decoded JSON output.

    The endpoint ID and API key are read from the ``RUNPOD_ENDPOINT_ID`` and
    ``RUNPOD_API_KEY`` environment variables. The model's JSON schema is sent
    with the prompt to the endpoint's ``/runsync`` route.

    Args:
        prompt: The prompt text forwarded to the endpoint.
        model: Class providing ``model_json_schema()``; its schema is sent as
            the desired output format.

    Returns:
        The value decoded by ``json.loads`` from the ``output`` field of the response.

    Raises:
        ValueError: If a required environment variable is missing or the API
            answers with a non-200 status code. Explicit raises are used
            instead of ``assert`` so the checks survive ``python -O``.
    """
    RUNPOD_ENDPOINT_ID = env.get("RUNPOD_ENDPOINT_ID")
    RUNPOD_API_KEY = env.get("RUNPOD_API_KEY")
    if not RUNPOD_ENDPOINT_ID:
        raise ValueError("RUNPOD_ENDPOINT_ID environment variable not set")
    if not RUNPOD_API_KEY:
        raise ValueError("RUNPOD_API_KEY environment variable not set")
    url = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/runsync"

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {RUNPOD_API_KEY}'
    }

    # Same schema accessor as the other workers; ``schema()`` is the deprecated spelling.
    schema = model.model_json_schema()
    data = {
        'input': {
            'schema': json.dumps(schema),
            'prompt': prompt
        }
    }

    response = requests.post(url, json=data, headers=headers)
    if response.status_code != 200:
        raise ValueError(f"Unexpected RunPod API status code: {response.status_code} with body: {response.text}")
    result = response.json()
    print(result)
    # TODO: After a 30 second timeout, a job ID is returned in the response instead,
    # and the client must poll the job status endpoint to get the result.
    output = result['output'].replace("model:mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf\n", "")
    # TODO: remove replacement once new version of runpod is deployed
    return json.loads(output)

# Global variables to enforce rate limiting
LAST_REQUEST_TIME = None  # time() of the most recent request; None until the first call
REQUEST_INTERVAL = 0.5  # Minimum time interval between requests in seconds

def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict[str, Any]]:
    """Send the prompt to the Mistral chat completions API and return the decoded JSON reply.

    Calls are spaced at least ``REQUEST_INTERVAL`` seconds apart by sleeping
    when the previous request was too recent. The endpoint and key are read
    from the ``MISTRAL_API_URL`` and ``MISTRAL_API_KEY`` environment variables.

    Args:
        prompt: The user message sent to the model.
        pydantic_model_class: Optional class providing ``model_validate_json()``.
            When given, the reply is validated against it before returning.

    Returns:
        The value decoded by ``json.loads`` from the first choice's message content.

    Raises:
        ValueError: If ``MISTRAL_API_KEY`` is not set, the API answers with a
            non-200 status code, or the reply is not valid JSON.
    """
    global LAST_REQUEST_TIME
    current_time = time()
    if LAST_REQUEST_TIME is not None:
        elapsed_time = current_time - LAST_REQUEST_TIME
        if elapsed_time < REQUEST_INTERVAL:
            sleep_time = REQUEST_INTERVAL - elapsed_time
            sleep(sleep_time)
            print(f"Slept for {sleep_time} seconds to enforce rate limit")
    LAST_REQUEST_TIME = time()

    MISTRAL_API_URL = env.get("MISTRAL_API_URL", "https://api.mistral.ai/v1/chat/completions")
    MISTRAL_API_KEY = env.get("MISTRAL_API_KEY", None)
    if not MISTRAL_API_KEY:
        raise ValueError("MISTRAL_API_KEY environment variable not set")
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {MISTRAL_API_KEY}'
    }
    data = {
        'model': 'mistral-small-latest',
        # JSON mode is a request-level option; it has no effect when nested
        # inside an individual message.
        'response_format': {'type': 'json_object'},
        'messages': [
            {
                'role': 'user',
                'content': prompt
            }
        ]
    }
    response = requests.post(MISTRAL_API_URL, headers=headers, json=data)
    if response.status_code != 200:
        raise ValueError(f"Unexpected Mistral API status code: {response.status_code} with body: {response.text}")
    result = response.json()
    print(result)
    output = result['choices'][0]['message']['content']
    if pydantic_model_class:
        parsed_result = pydantic_model_class.model_validate_json(output)
        print(parsed_result)
        # This will raise an exception if the model is invalid,
        # TODO: handle exception with retry logic
    else:
        print("No pydantic model class provided, returning without class validation")
    return json.loads(output)

def query_ai_prompt(prompt, replacements, model_class):
    """Fill the prompt template and send it to the backend selected by ``LLM_WORKER``.

    Args:
        prompt: Prompt template containing ``{key}`` placeholders.
        replacements: Mapping of placeholder name to replacement value.
        model_class: Class describing the expected JSON output; it is passed
            through to the selected worker function.

    Returns:
        Whatever the selected worker function returns for the filled-in prompt.

    Raises:
        ValueError: If ``LLM_WORKER`` does not name a known worker.
    """
    prompt = replace_text(prompt, replacements)
    if LLM_WORKER == "mistral":
        return llm_stream_mistral_api(prompt, model_class)
    if LLM_WORKER == "runpod":
        return llm_stream_serverless(prompt, model_class)
    if LLM_WORKER == "http":
        return llm_streaming(prompt, model_class)
    if LLM_WORKER == "in_memory":
        return llm_stream_sans_network(prompt, model_class)
    # Fail loudly rather than silently returning None for an unknown worker.
    raise ValueError(f"Invalid worker: {LLM_WORKER}")