cbys4 committed
Commit d9be087 · unverified · 1 Parent(s): c1554f5

Add files via upload

Files changed (1)
  1. PathRAG/llm.py +1104 -0
PathRAG/llm.py ADDED
@@ -0,0 +1,1104 @@
+ import base64
+ import copy
+ import json
+ import os
+ import re
+ import struct
+ from functools import lru_cache
+ from typing import List, Dict, Callable, Any, Union, Optional
+ import aioboto3
+ import aiohttp
+ import numpy as np
+ import ollama
+ import torch
+ import time
+ from openai import (
+     AsyncOpenAI,
+     APIConnectionError,
+     RateLimitError,
+     Timeout,
+     AsyncAzureOpenAI,
+ )
+ from pydantic import BaseModel, Field
+ from tenacity import (
+     retry,
+     stop_after_attempt,
+     wait_exponential,
+     retry_if_exception_type,
+ )
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ from .utils import (
+     wrap_embedding_func_with_attrs,
+     locate_json_string_body_from_string,
+     safe_unicode_decode,
+     logger,
+ )
+
+ import sys
+
+ if sys.version_info < (3, 9):
+     from typing import AsyncIterator
+ else:
+     from collections.abc import AsyncIterator
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def openai_complete_if_cache(
+     model,
+     prompt,
+     system_prompt=None,
+     history_messages=[],
+     base_url="https://api.openai.com/v1",
+     api_key="",
+     **kwargs,
+ ) -> str:
+     if api_key:
+         os.environ["OPENAI_API_KEY"] = api_key
+     time.sleep(2)
+     openai_async_client = (
+         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+     )
+     kwargs.pop("hashing_kv", None)
+     kwargs.pop("keyword_extraction", None)
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.extend(history_messages)
+     messages.append({"role": "user", "content": prompt})
+
+     logger.debug("===== Query Input to LLM =====")
+     logger.debug(f"Query: {prompt}")
+     logger.debug(f"System prompt: {system_prompt}")
+     logger.debug("Full context:")
+     if "response_format" in kwargs:
+         response = await openai_async_client.beta.chat.completions.parse(
+             model=model, messages=messages, **kwargs
+         )
+     else:
+         response = await openai_async_client.chat.completions.create(
+             model=model, messages=messages, **kwargs
+         )
+
+     if hasattr(response, "__aiter__"):
+
+         async def inner():
+             async for chunk in response:
+                 content = chunk.choices[0].delta.content
+                 if content is None:
+                     continue
+                 if r"\u" in content:
+                     content = safe_unicode_decode(content.encode("utf-8"))
+                 yield content
+
+         return inner()
+     else:
+         content = response.choices[0].message.content
+         if r"\u" in content:
+             content = safe_unicode_decode(content.encode("utf-8"))
+         return content
+
+
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def azure_openai_complete_if_cache(
+     model,
+     prompt,
+     system_prompt=None,
+     history_messages=[],
+     base_url=None,
+     api_key=None,
+     api_version=None,
+     **kwargs,
+ ):
+     if api_key:
+         os.environ["AZURE_OPENAI_API_KEY"] = api_key
+     if base_url:
+         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+     if api_version:
+         os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+     openai_async_client = AsyncAzureOpenAI(
+         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+         api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+         api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+     )
+     kwargs.pop("hashing_kv", None)
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.extend(history_messages)
+     if prompt is not None:
+         messages.append({"role": "user", "content": prompt})
+
+     response = await openai_async_client.chat.completions.create(
+         model=model, messages=messages, **kwargs
+     )
+     content = response.choices[0].message.content
+
+     return content
+
+
+ class BedrockError(Exception):
+     """Generic error for issues related to Amazon Bedrock"""
+
+
+ @retry(
+     stop=stop_after_attempt(5),
+     wait=wait_exponential(multiplier=1, max=60),
+     retry=retry_if_exception_type((BedrockError)),
+ )
+ async def bedrock_complete_if_cache(
+     model,
+     prompt,
+     system_prompt=None,
+     history_messages=[],
+     aws_access_key_id=None,
+     aws_secret_access_key=None,
+     aws_session_token=None,
+     **kwargs,
+ ) -> str:
+     os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+         "AWS_ACCESS_KEY_ID", aws_access_key_id
+     )
+     os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+         "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+     )
+     os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+         "AWS_SESSION_TOKEN", aws_session_token
+     )
+     kwargs.pop("hashing_kv", None)
+
+     messages = []
+     for history_message in history_messages:
+         message = copy.copy(history_message)
+         message["content"] = [{"text": message["content"]}]
+         messages.append(message)
+
+     messages.append({"role": "user", "content": [{"text": prompt}]})
+
+     args = {"modelId": model, "messages": messages}
+
+     if system_prompt:
+         args["system"] = [{"text": system_prompt}]
+
+     inference_params_map = {
+         "max_tokens": "maxTokens",
+         "top_p": "topP",
+         "stop_sequences": "stopSequences",
+     }
+     if inference_params := list(
+         set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
+     ):
+         args["inferenceConfig"] = {}
+         for param in inference_params:
+             args["inferenceConfig"][inference_params_map.get(param, param)] = (
+                 kwargs.pop(param)
+             )
+
+     session = aioboto3.Session()
+     async with session.client("bedrock-runtime") as bedrock_async_client:
+         try:
+             response = await bedrock_async_client.converse(**args, **kwargs)
+         except Exception as e:
+             raise BedrockError(e)
+
+         return response["output"]["message"]["content"][0]["text"]
+
+
+ @lru_cache(maxsize=1)
+ def initialize_hf_model(model_name):
+     hf_tokenizer = AutoTokenizer.from_pretrained(
+         model_name, device_map="auto", trust_remote_code=True
+     )
+     hf_model = AutoModelForCausalLM.from_pretrained(
+         model_name, device_map="auto", trust_remote_code=True
+     )
+     if hf_tokenizer.pad_token is None:
+         hf_tokenizer.pad_token = hf_tokenizer.eos_token
+
+     return hf_model, hf_tokenizer
+
+
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def hf_model_if_cache(
+     model,
+     prompt,
+     system_prompt=None,
+     history_messages=[],
+     **kwargs,
+ ) -> str:
+     model_name = model
+     hf_model, hf_tokenizer = initialize_hf_model(model_name)
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.extend(history_messages)
+     messages.append({"role": "user", "content": prompt})
+     kwargs.pop("hashing_kv", None)
+     input_prompt = ""
+     try:
+         input_prompt = hf_tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+     except Exception:
+         try:
+             ori_message = copy.deepcopy(messages)
+             if messages[0]["role"] == "system":
+                 messages[1]["content"] = (
+                     "<system>"
+                     + messages[0]["content"]
+                     + "</system>\n"
+                     + messages[1]["content"]
+                 )
+                 messages = messages[1:]
+             input_prompt = hf_tokenizer.apply_chat_template(
+                 messages, tokenize=False, add_generation_prompt=True
+             )
+         except Exception:
+             len_message = len(ori_message)
+             for msgid in range(len_message):
+                 input_prompt = (
+                     input_prompt
+                     + "<"
+                     + ori_message[msgid]["role"]
+                     + ">"
+                     + ori_message[msgid]["content"]
+                     + "</"
+                     + ori_message[msgid]["role"]
+                     + ">\n"
+                 )
+
+     input_ids = hf_tokenizer(
+         input_prompt, return_tensors="pt", padding=True, truncation=True
+     )
+     # Move all input tensors to the model's device (works on CPU as well as GPU).
+     inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
+     output = hf_model.generate(
+         **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
+     )
+     response_text = hf_tokenizer.decode(
+         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+     )
+
+     return response_text
+
+
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def ollama_model_if_cache(
+     model,
+     prompt,
+     system_prompt=None,
+     history_messages=[],
+     **kwargs,
+ ) -> Union[str, AsyncIterator[str]]:
+     stream = True if kwargs.get("stream") else False
+     kwargs.pop("max_tokens", None)
+     host = kwargs.pop("host", None)
+     timeout = kwargs.pop("timeout", None)
+     kwargs.pop("hashing_kv", None)
+     ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.extend(history_messages)
+     messages.append({"role": "user", "content": prompt})
+
+     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+     if stream:
+         """cannot cache stream response"""
+
+         async def inner():
+             async for chunk in response:
+                 yield chunk["message"]["content"]
+
+         return inner()
+     else:
+         return response["message"]["content"]
+
+
+ @lru_cache(maxsize=1)
+ def initialize_lmdeploy_pipeline(
+     model,
+     tp=1,
+     chat_template=None,
+     log_level="WARNING",
+     model_format="hf",
+     quant_policy=0,
+ ):
+     from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+
+     lmdeploy_pipe = pipeline(
+         model_path=model,
+         backend_config=TurbomindEngineConfig(
+             tp=tp, model_format=model_format, quant_policy=quant_policy
+         ),
+         chat_template_config=(
+             ChatTemplateConfig(model_name=chat_template) if chat_template else None
+         ),
+         log_level=log_level,
+     )
+     return lmdeploy_pipe
+
+
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def lmdeploy_model_if_cache(
+     model,
+     prompt,
+     system_prompt=None,
+     history_messages=[],
+     chat_template=None,
+     model_format="hf",
+     quant_policy=0,
+     **kwargs,
+ ) -> str:
+     """
+     Args:
+         model (str): The path to the model.
+             It could be one of the following options:
+             - i) A local directory path of a turbomind model which is
+                 converted by the `lmdeploy convert` command or downloaded
+                 from ii) and iii).
+             - ii) The model_id of an lmdeploy-quantized model hosted
+                 inside a model repo on huggingface.co, such as
+                 "InternLM/internlm-chat-20b-4bit",
+                 "lmdeploy/llama2-chat-70b-4bit", etc.
+             - iii) The model_id of a model hosted inside a model repo
+                 on huggingface.co, such as "internlm/internlm-chat-7b",
+                 "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
+                 and so on.
+         chat_template (str): needed when the model is a PyTorch model on
+             huggingface.co, such as "internlm-chat-7b",
+             "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on,
+             and when the model name of the local path does not match the original model name on HF.
+         tp (int): tensor parallel size
+         prompt (Union[str, List[str]]): input texts to be completed.
+         do_preprocess (bool): whether to pre-process the messages. Defaults to
+             True, which means chat_template will be applied.
+         skip_special_tokens (bool): whether to remove special tokens during
+             decoding. Defaults to True.
+         do_sample (bool): whether to use sampling; greedy decoding is used otherwise.
+             Defaults to False, which means greedy decoding is applied.
+     """
+     try:
+         import lmdeploy
+         from lmdeploy import version_info, GenerationConfig
+     except Exception:
+         raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
+     kwargs.pop("hashing_kv", None)
+     kwargs.pop("response_format", None)
+     max_new_tokens = kwargs.pop("max_tokens", 512)
+     tp = kwargs.pop("tp", 1)
+     skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+     do_preprocess = kwargs.pop("do_preprocess", True)
+     do_sample = kwargs.pop("do_sample", False)
+     gen_params = kwargs
+
+     version = version_info
+     if do_sample is not None and version < (0, 6, 0):
+         raise RuntimeError(
+             "`do_sample` parameter is not supported by lmdeploy until "
+             f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
+         )
+     else:
+         do_sample = True
+         gen_params.update(do_sample=do_sample)
+
+     lmdeploy_pipe = initialize_lmdeploy_pipeline(
+         model=model,
+         tp=tp,
+         chat_template=chat_template,
+         model_format=model_format,
+         quant_policy=quant_policy,
+         log_level="WARNING",
+     )
+
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+
+     messages.extend(history_messages)
+     messages.append({"role": "user", "content": prompt})
+
+     gen_config = GenerationConfig(
+         skip_special_tokens=skip_special_tokens,
+         max_new_tokens=max_new_tokens,
+         **gen_params,
+     )
+
+     response = ""
+     async for res in lmdeploy_pipe.generate(
+         messages,
+         gen_config=gen_config,
+         do_preprocess=do_preprocess,
+         stream_response=False,
+         session_id=1,
+     ):
+         response += res.response
+     return response
+
+
+ class GPTKeywordExtractionFormat(BaseModel):
+     high_level_keywords: List[str]
+     low_level_keywords: List[str]
+
+
+ async def openai_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> Union[str, AsyncIterator[str]]:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     if keyword_extraction:
+         kwargs["response_format"] = "json"
+     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+     return await openai_complete_if_cache(
+         model_name,
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+
+
+ async def gpt_4o_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     if keyword_extraction:
+         kwargs["response_format"] = GPTKeywordExtractionFormat
+     return await openai_complete_if_cache(
+         "gpt-4o",
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+
+
+ async def gpt_4o_mini_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     if keyword_extraction:
+         kwargs["response_format"] = GPTKeywordExtractionFormat
+     return await openai_complete_if_cache(
+         "gpt-4o-mini",
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+
+
+ async def nvidia_openai_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     result = await openai_complete_if_cache(
+         "nvidia/llama-3.1-nemotron-70b-instruct",
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         base_url="https://integrate.api.nvidia.com/v1",
+         **kwargs,
+     )
+     if keyword_extraction:  # TODO: use JSON API
+         return locate_json_string_body_from_string(result)
+     return result
+
+
+ async def azure_openai_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     result = await azure_openai_complete_if_cache(
+         "conversation-4o-mini",
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+     if keyword_extraction:  # TODO: use JSON API
+         return locate_json_string_body_from_string(result)
+     return result
+
+
+ async def bedrock_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     result = await bedrock_complete_if_cache(
+         "anthropic.claude-3-haiku-20240307-v1:0",
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+     if keyword_extraction:  # TODO: use JSON API
+         return locate_json_string_body_from_string(result)
+     return result
+
+
+ async def hf_model_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+     result = await hf_model_if_cache(
+         model_name,
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+     if keyword_extraction:  # TODO: use JSON API
+         return locate_json_string_body_from_string(result)
+     return result
+
+
+ async def ollama_model_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> Union[str, AsyncIterator[str]]:
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+     if keyword_extraction:
+         kwargs["format"] = "json"
+     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+     return await ollama_model_if_cache(
+         model_name,
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+
+
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def zhipu_complete_if_cache(
+     prompt: Union[str, List[Dict[str, str]]],
+     model: str = "glm-4-flashx",
+     api_key: Optional[str] = None,
+     system_prompt: Optional[str] = None,
+     history_messages: List[Dict[str, str]] = [],
+     **kwargs,
+ ) -> str:
+     try:
+         from zhipuai import ZhipuAI
+     except ImportError:
+         raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
+
+     if api_key:
+         client = ZhipuAI(api_key=api_key)
+     else:
+         client = ZhipuAI()
+
+     messages = []
+
+     if not system_prompt:
+         system_prompt = "You are a helpful assistant. Replace any sensitive words in the content with ***."
+
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.extend(history_messages)
+     messages.append({"role": "user", "content": prompt})
+
+     logger.debug("===== Query Input to LLM =====")
+     logger.debug(f"Query: {prompt}")
+     logger.debug(f"System prompt: {system_prompt}")
+
+     kwargs = {
+         k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
+     }
+
+     response = client.chat.completions.create(model=model, messages=messages, **kwargs)
+
+     return response.choices[0].message.content
+
+
+ async def zhipu_complete(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ):
+     keyword_extraction = kwargs.pop("keyword_extraction", None)
+
+     if keyword_extraction:
+         extraction_prompt = """You are a helpful assistant that extracts keywords from text.
+         Please analyze the content and extract two types of keywords:
+         1. High-level keywords: Important concepts and main themes
+         2. Low-level keywords: Specific details and supporting elements
+
+         Return your response in this exact JSON format:
+         {
+             "high_level_keywords": ["keyword1", "keyword2"],
+             "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
+         }
+
+         Only return the JSON, no other text."""
+
+         if system_prompt:
+             system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
+         else:
+             system_prompt = extraction_prompt
+
+         try:
+             response = await zhipu_complete_if_cache(
+                 prompt=prompt,
+                 system_prompt=system_prompt,
+                 history_messages=history_messages,
+                 **kwargs,
+             )
+
+             try:
+                 data = json.loads(response)
+                 return GPTKeywordExtractionFormat(
+                     high_level_keywords=data.get("high_level_keywords", []),
+                     low_level_keywords=data.get("low_level_keywords", []),
+                 )
+             except json.JSONDecodeError:
+                 match = re.search(r"\{[\s\S]*\}", response)
+                 if match:
+                     try:
+                         data = json.loads(match.group())
+                         return GPTKeywordExtractionFormat(
+                             high_level_keywords=data.get("high_level_keywords", []),
+                             low_level_keywords=data.get("low_level_keywords", []),
+                         )
+                     except json.JSONDecodeError:
+                         pass
+
+                 logger.warning(
+                     f"Failed to parse keyword extraction response: {response}"
+                 )
+                 return GPTKeywordExtractionFormat(
+                     high_level_keywords=[], low_level_keywords=[]
+                 )
+         except Exception as e:
+             logger.error(f"Error during keyword extraction: {str(e)}")
+             return GPTKeywordExtractionFormat(
+                 high_level_keywords=[], low_level_keywords=[]
+             )
+     else:
+         return await zhipu_complete_if_cache(
+             prompt=prompt,
+             system_prompt=system_prompt,
+             history_messages=history_messages,
+             **kwargs,
+         )
+
+
+ @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=60),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def zhipu_embedding(
+     texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
+ ) -> np.ndarray:
+     try:
+         from zhipuai import ZhipuAI
+     except ImportError:
+         raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
+     if api_key:
+         client = ZhipuAI(api_key=api_key)
+     else:
+         client = ZhipuAI()
+
+     if isinstance(texts, str):
+         texts = [texts]
+
+     embeddings = []
+     for text in texts:
+         try:
+             response = client.embeddings.create(model=model, input=[text], **kwargs)
+             embeddings.append(response.data[0].embedding)
+         except Exception as e:
+             raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
+
+     return np.array(embeddings)
+
+
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=60),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def openai_embedding(
+     texts: list[str],
+     model: str = "text-embedding-3-small",
+     base_url="https://api.openai.com/v1",
+     api_key="",
+ ) -> np.ndarray:
+     if api_key:
+         os.environ["OPENAI_API_KEY"] = api_key
+
+     openai_async_client = (
+         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+     )
+     response = await openai_async_client.embeddings.create(
+         model=model, input=texts, encoding_format="float"
+     )
+     return np.array([dp.embedding for dp in response.data])
+
+
+ async def fetch_data(url, headers, data):
+     async with aiohttp.ClientSession() as session:
+         async with session.post(url, headers=headers, json=data) as response:
+             response_json = await response.json()
+             data_list = response_json.get("data", [])
+             return data_list
+
+
+ async def jina_embedding(
+     texts: list[str],
+     dimensions: int = 1024,
+     late_chunking: bool = False,
+     base_url: str = None,
+     api_key: str = None,
+ ) -> np.ndarray:
+     if api_key:
+         os.environ["JINA_API_KEY"] = api_key
+     url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
+     headers = {
+         "Content-Type": "application/json",
+         "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
+     }
+     data = {
+         "model": "jina-embeddings-v3",
+         "normalized": True,
+         "embedding_type": "float",
+         "dimensions": f"{dimensions}",
+         "late_chunking": late_chunking,
+         "input": texts,
+     }
+     data_list = await fetch_data(url, headers, data)
+     return np.array([dp["embedding"] for dp in data_list])
+
+
+ @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=60),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def nvidia_openai_embedding(
+     texts: list[str],
+     model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
+     base_url: str = "https://integrate.api.nvidia.com/v1",
+     api_key: str = None,
+     input_type: str = "passage",
+     trunc: str = "NONE",
+     encode: str = "float",
+ ) -> np.ndarray:
+     if api_key:
+         os.environ["OPENAI_API_KEY"] = api_key
+
+     openai_async_client = (
+         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+     )
+     response = await openai_async_client.embeddings.create(
+         model=model,
+         input=texts,
+         encoding_format=encode,
+         extra_body={"input_type": input_type, "truncate": trunc},
+     )
+     return np.array([dp.embedding for dp in response.data])
+
+
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def azure_openai_embedding(
+     texts: list[str],
+     model: str = "text-embedding-3-small",
+     base_url: str = None,
+     api_key: str = None,
+     api_version: str = None,
+ ) -> np.ndarray:
+     if api_key:
+         os.environ["AZURE_OPENAI_API_KEY"] = api_key
+     if base_url:
+         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+     if api_version:
+         os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+
+     openai_async_client = AsyncAzureOpenAI(
+         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+         api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+         api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+     )
+
+     response = await openai_async_client.embeddings.create(
+         model=model, input=texts, encoding_format="float"
+     )
+     return np.array([dp.embedding for dp in response.data])
+
+
+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=60),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def siliconcloud_embedding(
+     texts: list[str],
+     model: str = "netease-youdao/bce-embedding-base_v1",
+     base_url: str = "https://api.siliconflow.cn/v1/embeddings",
+     max_token_size: int = 512,
+     api_key: str = None,
+ ) -> np.ndarray:
+     if api_key and not api_key.startswith("Bearer "):
+         api_key = "Bearer " + api_key
+
+     headers = {"Authorization": api_key, "Content-Type": "application/json"}
+
+     truncate_texts = [text[0:max_token_size] for text in texts]
+
+     payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
+
+     base64_strings = []
+     async with aiohttp.ClientSession() as session:
+         async with session.post(base_url, headers=headers, json=payload) as response:
+             content = await response.json()
+             if "code" in content:
+                 raise ValueError(content)
+             base64_strings = [item["embedding"] for item in content["data"]]
+
+     embeddings = []
+     for string in base64_strings:
+         decode_bytes = base64.b64decode(string)
+         n = len(decode_bytes) // 4
+         float_array = struct.unpack("<" + "f" * n, decode_bytes)
+         embeddings.append(float_array)
+     return np.array(embeddings)
+
+
+ async def bedrock_embedding(
+     texts: list[str],
+     model: str = "amazon.titan-embed-text-v2:0",
+     aws_access_key_id=None,
+     aws_secret_access_key=None,
+     aws_session_token=None,
+ ) -> np.ndarray:
+     os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+         "AWS_ACCESS_KEY_ID", aws_access_key_id
+     )
+     os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+         "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+     )
+     os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+         "AWS_SESSION_TOKEN", aws_session_token
+     )
+
+     session = aioboto3.Session()
+     async with session.client("bedrock-runtime") as bedrock_async_client:
+         if (model_provider := model.split(".")[0]) == "amazon":
+             embed_texts = []
+             for text in texts:
+                 if "v2" in model:
+                     body = json.dumps(
+                         {
+                             "inputText": text,
+                             "embeddingTypes": ["float"],
+                         }
+                     )
+                 elif "v1" in model:
+                     body = json.dumps({"inputText": text})
+                 else:
+                     raise ValueError(f"Model {model} is not supported!")
+
+                 response = await bedrock_async_client.invoke_model(
+                     modelId=model,
+                     body=body,
+                     accept="application/json",
+                     contentType="application/json",
+                 )
+
+                 response_body = await response.get("body").json()
+
+                 embed_texts.append(response_body["embedding"])
+         elif model_provider == "cohere":
+             body = json.dumps(
+                 {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
+             )
+
+             response = await bedrock_async_client.invoke_model(
+                 modelId=model,
+                 body=body,
+                 accept="application/json",
+                 contentType="application/json",
+             )
+
+             response_body = json.loads(response.get("body").read())
+
+             embed_texts = response_body["embeddings"]
+         else:
+             raise ValueError(f"Model provider '{model_provider}' is not supported!")
+
+         return np.array(embed_texts)
+
+
+ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+     device = next(embed_model.parameters()).device
+     input_ids = tokenizer(
+         texts, return_tensors="pt", padding=True, truncation=True
+     ).input_ids.to(device)
+     with torch.no_grad():
+         outputs = embed_model(input_ids)
+         embeddings = outputs.last_hidden_state.mean(dim=1)
+     if embeddings.dtype == torch.bfloat16:
+         return embeddings.detach().to(torch.float32).cpu().numpy()
+     else:
+         return embeddings.detach().cpu().numpy()
+
+
+ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
+     """
+     Deprecated in favor of `embed`.
+     """
+     embed_text = []
+     ollama_client = ollama.Client(**kwargs)
+     for text in texts:
+         data = ollama_client.embeddings(model=embed_model, prompt=text)
+         embed_text.append(data["embedding"])
+
+     return embed_text
+
+
+ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
+     ollama_client = ollama.Client(**kwargs)
+     data = ollama_client.embed(model=embed_model, input=texts)
+     return data["embeddings"]
+
+
+ class Model(BaseModel):
+     """
+     This is a Pydantic model class named 'Model' that is used to define a custom language model.
+
+     Attributes:
+         gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
+             The function should take any argument and return a string.
+         kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
+             This could include parameters such as the model name, API key, etc.
+
+     Example usage:
+         Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
+
+     In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
+     The 'kwargs' dictionary contains the model name and API key to be passed to the function.
+     """
+
+     gen_func: Callable[[Any], str] = Field(
+         ...,
+         description="A function that generates the response from the llm. The response must be a string",
+     )
+     kwargs: Dict[str, Any] = Field(
+         ...,
+         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
+     )
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ class MultiModel:
+     """
+     Distributes the load across multiple language models. Useful for circumventing low rate limits with certain API providers, especially if you are on the free tier.
+     Could also be used for splitting the load across different models or providers.
+
+     Attributes:
+         models (List[Model]): A list of language models to be used.
+
+     Usage example:
+     ```python
+     models = [
+         Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
+         Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
+         Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
+         Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
+         Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
+     ]
+     multi_model = MultiModel(models)
+     rag = LightRAG(
+         llm_model_func=multi_model.llm_model_func,
+         # ...other args
+     )
+     ```
+     """
+
+     def __init__(self, models: List[Model]):
+         self._models = models
+         self._current_model = 0
+
+     def _next_model(self):
+         self._current_model = (self._current_model + 1) % len(self._models)
+         return self._models[self._current_model]
+
+     async def llm_model_func(
+         self, prompt, system_prompt=None, history_messages=[], **kwargs
+     ) -> str:
+         kwargs.pop("model", None)
+         kwargs.pop("keyword_extraction", None)
+         kwargs.pop("mode", None)
+         next_model = self._next_model()
+         args = dict(
+             prompt=prompt,
+             system_prompt=system_prompt,
+             history_messages=history_messages,
+             **kwargs,
+             **next_model.kwargs,
+         )
+
+         return await next_model.gen_func(**args)
+
+
+ if __name__ == "__main__":
+     import asyncio
+
+     async def main():
+         result = await gpt_4o_mini_complete("How are you?")
+         print(result)
+
+     asyncio.run(main())
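A minimal sketch of how the helpers added in this file might be called directly, assuming `PathRAG` is importable as an installed package (the module uses a relative `from .utils import ...`), an `OPENAI_API_KEY` is available, and the default OpenAI endpoint is reachable; the model names and texts below are only illustrative:

```python
import asyncio
import os

from PathRAG.llm import openai_complete_if_cache, openai_embedding


async def demo():
    # Chat completion through the retry-wrapped OpenAI helper.
    answer = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize what a retrieval-augmented generation pipeline does.",
        system_prompt="You are a concise assistant.",
        api_key=os.environ["OPENAI_API_KEY"],
    )
    print(answer)

    # Embeddings come back as a numpy array of shape (n_texts, embedding_dim).
    vectors = await openai_embedding(
        ["PathRAG indexes documents.", "Queries retrieve relevant paths."],
        api_key=os.environ["OPENAI_API_KEY"],
    )
    print(vectors.shape)


asyncio.run(demo())
```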