partoneplay committed

Commit ce8e264
Parents: f4c9977 b6dce7a

Merge remote-tracking branch 'origin/main' and fix syntax
README.md CHANGED
@@ -596,11 +596,7 @@ if __name__ == "__main__":
596
  | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
597
  | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
598
  | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
599
- | **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters:
600
- - `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached.
601
- - `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
602
-
603
- Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` |
604
 
605
  ## API Server Implementation
606
 
 
596
  | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
597
  | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
598
  | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
599
+ | **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup. When enabled, the system checks cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1). When a new question's similarity with a cached question exceeds this threshold, the cached answer is returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, the LLM is used as a secondary check to verify question similarity before a cached answer is returned. | `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
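The sketch below is an illustration only, not part of this commit: `embedding_cache_config` is passed straight to the `LightRAG` constructor, which the `lightrag/lightrag.py` change in this commit defines as a dataclass field. `llm_model_func` and `embedding_func` stand in for whatever completion and embedding functions you already use (see the example scripts).

```python
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=llm_model_func,  # your async completion function
    embedding_func=EmbeddingFunc(
        embedding_dim=1024, max_token_size=8192, func=embedding_func
    ),
    embedding_cache_config={
        "enabled": True,               # turn on question-answer cache lookup
        "similarity_threshold": 0.95,  # reuse cached answers above this similarity
        "use_llm_check": False,        # set True for the extra LLM verification pass
    },
)
```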
 
 
 
 
600
 
601
  ## API Server Implementation
602
 
examples/graph_visual_with_html.py CHANGED
@@ -11,9 +11,16 @@ net = Network(height="100vh", notebook=True)
11
  # Convert NetworkX graph to Pyvis network
12
  net.from_nx(G)
13
 
14
- # Add colors to nodes
15
  for node in net.nodes:
16
  node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
17
 
18
  # Save and display the network
19
  net.show("knowledge_graph.html")
 
11
  # Convert NetworkX graph to Pyvis network
12
  net.from_nx(G)
13
 
14
+ # Add colors and title to nodes
15
  for node in net.nodes:
16
  node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
17
+ if "description" in node:
18
+ node["title"] = node["description"]
19
+
20
+ # Add title to edges
21
+ for edge in net.edges:
22
+ if "description" in edge:
23
+ edge["title"] = edge["description"]
24
 
25
  # Save and display the network
26
  net.show("knowledge_graph.html")
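In Pyvis (vis.js), a node's or edge's `title` field is rendered as its hover tooltip, so this change surfaces the stored `description` text when hovering over the graph in the generated `knowledge_graph.html`.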
examples/lightrag_jinaai_demo.py ADDED
@@ -0,0 +1,114 @@
1
+ import numpy as np
2
+ from lightrag import LightRAG, QueryParam
3
+ from lightrag.utils import EmbeddingFunc
4
+ from lightrag.llm import jina_embedding, openai_complete_if_cache
5
+ import os
6
+ import asyncio
7
+
8
+
9
+ async def embedding_func(texts: list[str]) -> np.ndarray:
10
+ return await jina_embedding(texts, api_key="YourJinaAPIKey")
11
+
12
+
13
+ WORKING_DIR = "./dickens"
14
+
15
+ if not os.path.exists(WORKING_DIR):
16
+ os.mkdir(WORKING_DIR)
17
+
18
+
19
+ async def llm_model_func(
20
+ prompt, system_prompt=None, history_messages=[], **kwargs
21
+ ) -> str:
22
+ return await openai_complete_if_cache(
23
+ "solar-mini",
24
+ prompt,
25
+ system_prompt=system_prompt,
26
+ history_messages=history_messages,
27
+ api_key=os.getenv("UPSTAGE_API_KEY"),
28
+ base_url="https://api.upstage.ai/v1/solar",
29
+ **kwargs,
30
+ )
31
+
32
+
33
+ rag = LightRAG(
34
+ working_dir=WORKING_DIR,
35
+ llm_model_func=llm_model_func,
36
+ embedding_func=EmbeddingFunc(
37
+ embedding_dim=1024, max_token_size=8192, func=embedding_func
38
+ ),
39
+ )
40
+
41
+
42
+ async def lightraginsert(file_path, semaphore):
43
+ async with semaphore:
44
+ try:
45
+ with open(file_path, "r", encoding="utf-8") as f:
46
+ content = f.read()
47
+ except UnicodeDecodeError:
48
+ # If UTF-8 decoding fails, try other encodings
49
+ with open(file_path, "r", encoding="gbk") as f:
50
+ content = f.read()
51
+ await rag.ainsert(content)
52
+
53
+
54
+ async def process_files(directory, concurrency_limit):
55
+ semaphore = asyncio.Semaphore(concurrency_limit)
56
+ tasks = []
57
+ for root, dirs, files in os.walk(directory):
58
+ for f in files:
59
+ file_path = os.path.join(root, f)
60
+ if f.startswith("."):
61
+ continue
62
+ tasks.append(lightraginsert(file_path, semaphore))
63
+ await asyncio.gather(*tasks)
64
+
65
+
66
+ async def main():
67
+ try:
68
+ rag = LightRAG(
69
+ working_dir=WORKING_DIR,
70
+ llm_model_func=llm_model_func,
71
+ embedding_func=EmbeddingFunc(
72
+ embedding_dim=1024,
73
+ max_token_size=8192,
74
+ func=embedding_func,
75
+ ),
76
+ )
77
+
78
+ await process_files(WORKING_DIR, concurrency_limit=4)  # already inside the event loop, so await directly
79
+
80
+ # Perform naive search
81
+ print(
82
+ await rag.aquery(
83
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
84
+ )
85
+ )
86
+
87
+ # Perform local search
88
+ print(
89
+ await rag.aquery(
90
+ "What are the top themes in this story?", param=QueryParam(mode="local")
91
+ )
92
+ )
93
+
94
+ # Perform global search
95
+ print(
96
+ await rag.aquery(
97
+ "What are the top themes in this story?",
98
+ param=QueryParam(mode="global"),
99
+ )
100
+ )
101
+
102
+ # Perform hybrid search
103
+ print(
104
+ await rag.aquery(
105
+ "What are the top themes in this story?",
106
+ param=QueryParam(mode="hybrid"),
107
+ )
108
+ )
109
+ except Exception as e:
110
+ print(f"An error occurred: {e}")
111
+
112
+
113
+ if __name__ == "__main__":
114
+ asyncio.run(main())
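A note on running this example, inferred only from the script above: the documents to index are read from `./dickens`, `UPSTAGE_API_KEY` must be set for the Solar completion endpoint, and the `"YourJinaAPIKey"` placeholder has to be replaced with a real Jina API key (or dropped, so that the `jina_embedding` helper added in `lightrag/llm.py` below reads `JINA_API_KEY` from the environment).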
lightrag/lightrag.py CHANGED
@@ -87,7 +87,11 @@ class LightRAG:
87
  )
88
  # Default not to use embedding cache
89
  embedding_cache_config: dict = field(
90
- default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
 
 
 
 
91
  )
92
  kv_storage: str = field(default="JsonKVStorage")
93
  vector_storage: str = field(default="NanoVectorDBStorage")
@@ -174,7 +178,6 @@ class LightRAG:
174
  if self.enable_llm_cache
175
  else None
176
  )
177
-
178
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
179
  self.embedding_func
180
  )
@@ -481,6 +484,7 @@ class LightRAG:
481
  self.text_chunks,
482
  param,
483
  asdict(self),
 
484
  )
485
  elif param.mode == "naive":
486
  response = await naive_query(
@@ -489,6 +493,7 @@ class LightRAG:
489
  self.text_chunks,
490
  param,
491
  asdict(self),
 
492
  )
493
  else:
494
  raise ValueError(f"Unknown mode {param.mode}")
 
87
  )
88
  # Default not to use embedding cache
89
  embedding_cache_config: dict = field(
90
+ default_factory=lambda: {
91
+ "enabled": False,
92
+ "similarity_threshold": 0.95,
93
+ "use_llm_check": False,
94
+ }
95
  )
96
  kv_storage: str = field(default="JsonKVStorage")
97
  vector_storage: str = field(default="NanoVectorDBStorage")
 
178
  if self.enable_llm_cache
179
  else None
180
  )
 
181
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
182
  self.embedding_func
183
  )
 
484
  self.text_chunks,
485
  param,
486
  asdict(self),
487
+ hashing_kv=self.llm_response_cache,
488
  )
489
  elif param.mode == "naive":
490
  response = await naive_query(
 
493
  self.text_chunks,
494
  param,
495
  asdict(self),
496
+ hashing_kv=self.llm_response_cache,
497
  )
498
  else:
499
  raise ValueError(f"Unknown mode {param.mode}")
lightrag/llm.py CHANGED
@@ -4,8 +4,7 @@ import json
4
  import os
5
  import struct
6
  from functools import lru_cache
7
- from typing import List, Dict, Callable, Any, Union, Optional
8
- from dataclasses import dataclass
9
  import aioboto3
10
  import aiohttp
11
  import numpy as np
@@ -27,13 +26,9 @@ from tenacity import (
27
  )
28
  from transformers import AutoTokenizer, AutoModelForCausalLM
29
 
30
- from .base import BaseKVStorage
31
  from .utils import (
32
- compute_args_hash,
33
  wrap_embedding_func_with_attrs,
34
  locate_json_string_body_from_string,
35
- quantize_embedding,
36
- get_best_cached_response,
37
  )
38
 
39
  import sys
@@ -66,23 +61,13 @@ async def openai_complete_if_cache(
66
  openai_async_client = (
67
  AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
68
  )
69
-
70
  messages = []
71
  if system_prompt:
72
  messages.append({"role": "system", "content": system_prompt})
73
  messages.extend(history_messages)
74
  messages.append({"role": "user", "content": prompt})
75
 
76
- hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
77
- # Handle cache
78
- mode = kwargs.pop("mode", "default")
79
- args_hash = compute_args_hash(model, messages)
80
- cached_response, quantized, min_val, max_val = await handle_cache(
81
- hashing_kv, args_hash, prompt, mode
82
- )
83
- if cached_response is not None:
84
- return cached_response
85
-
86
  if "response_format" in kwargs:
87
  response = await openai_async_client.beta.chat.completions.parse(
88
  model=model, messages=messages, **kwargs
@@ -108,22 +93,6 @@ async def openai_complete_if_cache(
108
  content = response.choices[0].message.content
109
  if r"\u" in content:
110
  content = content.encode("utf-8").decode("unicode_escape")
111
-
112
- # Save to cache
113
- await save_to_cache(
114
- hashing_kv,
115
- CacheData(
116
- args_hash=args_hash,
117
- content=content,
118
- model=model,
119
- prompt=prompt,
120
- quantized=quantized,
121
- min_val=min_val,
122
- max_val=max_val,
123
- mode=mode,
124
- ),
125
- )
126
-
127
  return content
128
 
129
 
@@ -154,10 +123,7 @@ async def azure_openai_complete_if_cache(
154
  api_key=os.getenv("AZURE_OPENAI_API_KEY"),
155
  api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
156
  )
157
-
158
- hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
159
- mode = kwargs.pop("mode", "default")
160
-
161
  messages = []
162
  if system_prompt:
163
  messages.append({"role": "system", "content": system_prompt})
@@ -165,34 +131,11 @@ async def azure_openai_complete_if_cache(
165
  if prompt is not None:
166
  messages.append({"role": "user", "content": prompt})
167
 
168
- # Handle cache
169
- args_hash = compute_args_hash(model, messages)
170
- cached_response, quantized, min_val, max_val = await handle_cache(
171
- hashing_kv, args_hash, prompt, mode
172
- )
173
- if cached_response is not None:
174
- return cached_response
175
-
176
  response = await openai_async_client.chat.completions.create(
177
  model=model, messages=messages, **kwargs
178
  )
179
  content = response.choices[0].message.content
180
 
181
- # Save to cache
182
- await save_to_cache(
183
- hashing_kv,
184
- CacheData(
185
- args_hash=args_hash,
186
- content=content,
187
- model=model,
188
- prompt=prompt,
189
- quantized=quantized,
190
- min_val=min_val,
191
- max_val=max_val,
192
- mode=mode,
193
- ),
194
- )
195
-
196
  return content
197
 
198
 
@@ -224,7 +167,7 @@ async def bedrock_complete_if_cache(
224
  os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
225
  "AWS_SESSION_TOKEN", aws_session_token
226
  )
227
-
228
  # Fix message history format
229
  messages = []
230
  for history_message in history_messages:
@@ -234,15 +177,6 @@ async def bedrock_complete_if_cache(
234
 
235
  # Add user prompt
236
  messages.append({"role": "user", "content": [{"text": prompt}]})
237
- hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
238
- # Handle cache
239
- mode = kwargs.pop("mode", "default")
240
- args_hash = compute_args_hash(model, messages)
241
- cached_response, quantized, min_val, max_val = await handle_cache(
242
- hashing_kv, args_hash, prompt, mode
243
- )
244
- if cached_response is not None:
245
- return cached_response
246
 
247
  # Initialize Converse API arguments
248
  args = {"modelId": model, "messages": messages}
@@ -265,15 +199,6 @@ async def bedrock_complete_if_cache(
265
  args["inferenceConfig"][inference_params_map.get(param, param)] = (
266
  kwargs.pop(param)
267
  )
268
- hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
269
- # Handle cache
270
- mode = kwargs.pop("mode", "default")
271
- args_hash = compute_args_hash(model, messages)
272
- cached_response, quantized, min_val, max_val = await handle_cache(
273
- hashing_kv, args_hash, prompt, mode
274
- )
275
- if cached_response is not None:
276
- return cached_response
277
 
278
  # Call model via Converse API
279
  session = aioboto3.Session()
@@ -283,21 +208,6 @@ async def bedrock_complete_if_cache(
283
  except Exception as e:
284
  raise BedrockError(e)
285
 
286
- # Save to cache
287
- await save_to_cache(
288
- hashing_kv,
289
- CacheData(
290
- args_hash=args_hash,
291
- content=response["output"]["message"]["content"][0]["text"],
292
- model=model,
293
- prompt=prompt,
294
- quantized=quantized,
295
- min_val=min_val,
296
- max_val=max_val,
297
- mode=mode,
298
- ),
299
- )
300
-
301
  return response["output"]["message"]["content"][0]["text"]
302
 
303
 
@@ -329,22 +239,12 @@ async def hf_model_if_cache(
329
  ) -> str:
330
  model_name = model
331
  hf_model, hf_tokenizer = initialize_hf_model(model_name)
332
- hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
333
  messages = []
334
  if system_prompt:
335
  messages.append({"role": "system", "content": system_prompt})
336
  messages.extend(history_messages)
337
  messages.append({"role": "user", "content": prompt})
338
-
339
- # Handle cache
340
- mode = kwargs.pop("mode", "default")
341
- args_hash = compute_args_hash(model, messages)
342
- cached_response, quantized, min_val, max_val = await handle_cache(
343
- hashing_kv, args_hash, prompt, mode
344
- )
345
- if cached_response is not None:
346
- return cached_response
347
-
348
  input_prompt = ""
349
  try:
350
  input_prompt = hf_tokenizer.apply_chat_template(
@@ -389,21 +289,6 @@ async def hf_model_if_cache(
389
  output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
390
  )
391
 
392
- # Save to cache
393
- await save_to_cache(
394
- hashing_kv,
395
- CacheData(
396
- args_hash=args_hash,
397
- content=response_text,
398
- model=model,
399
- prompt=prompt,
400
- quantized=quantized,
401
- min_val=min_val,
402
- max_val=max_val,
403
- mode=mode,
404
- ),
405
- )
406
-
407
  return response_text
408
 
409
 
@@ -424,25 +309,14 @@ async def ollama_model_if_cache(
424
  # kwargs.pop("response_format", None) # allow json
425
  host = kwargs.pop("host", None)
426
  timeout = kwargs.pop("timeout", None)
427
-
428
  ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
429
  messages = []
430
  if system_prompt:
431
  messages.append({"role": "system", "content": system_prompt})
432
-
433
- hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
434
  messages.extend(history_messages)
435
  messages.append({"role": "user", "content": prompt})
436
 
437
- # Handle cache
438
- mode = kwargs.pop("mode", "default")
439
- args_hash = compute_args_hash(model, messages)
440
- cached_response, quantized, min_val, max_val = await handle_cache(
441
- hashing_kv, args_hash, prompt, mode
442
- )
443
- if cached_response is not None:
444
- return cached_response
445
-
446
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
447
  if stream:
448
  """cannot cache stream response"""
@@ -453,22 +327,7 @@ async def ollama_model_if_cache(
453
 
454
  return inner()
455
  else:
456
- result = response["message"]["content"]
457
- # Save to cache
458
- await save_to_cache(
459
- hashing_kv,
460
- CacheData(
461
- args_hash=args_hash,
462
- content=result,
463
- model=model,
464
- prompt=prompt,
465
- quantized=quantized,
466
- min_val=min_val,
467
- max_val=max_val,
468
- mode=mode,
469
- ),
470
- )
471
- return result
472
 
473
 
474
  @lru_cache(maxsize=1)
@@ -543,7 +402,7 @@ async def lmdeploy_model_if_cache(
543
  from lmdeploy import version_info, GenerationConfig
544
  except Exception:
545
  raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
546
-
547
  kwargs.pop("response_format", None)
548
  max_new_tokens = kwargs.pop("max_tokens", 512)
549
  tp = kwargs.pop("tp", 1)
@@ -575,19 +434,9 @@ async def lmdeploy_model_if_cache(
575
  if system_prompt:
576
  messages.append({"role": "system", "content": system_prompt})
577
 
578
- hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
579
  messages.extend(history_messages)
580
  messages.append({"role": "user", "content": prompt})
581
 
582
- # Handle cache
583
- mode = kwargs.pop("mode", "default")
584
- args_hash = compute_args_hash(model, messages)
585
- cached_response, quantized, min_val, max_val = await handle_cache(
586
- hashing_kv, args_hash, prompt, mode
587
- )
588
- if cached_response is not None:
589
- return cached_response
590
-
591
  gen_config = GenerationConfig(
592
  skip_special_tokens=skip_special_tokens,
593
  max_new_tokens=max_new_tokens,
@@ -603,22 +452,6 @@ async def lmdeploy_model_if_cache(
603
  session_id=1,
604
  ):
605
  response += res.response
606
-
607
- # Save to cache
608
- await save_to_cache(
609
- hashing_kv,
610
- CacheData(
611
- args_hash=args_hash,
612
- content=response,
613
- model=model,
614
- prompt=prompt,
615
- quantized=quantized,
616
- min_val=min_val,
617
- max_val=max_val,
618
- mode=mode,
619
- ),
620
- )
621
-
622
  return response
623
 
624
 
@@ -779,6 +612,40 @@ async def openai_embedding(
779
  return np.array([dp.embedding for dp in response.data])
780
 
781
 
782
  @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
783
  @retry(
784
  stop=stop_after_attempt(3),
@@ -1064,77 +931,6 @@ class MultiModel:
1064
  return await next_model.gen_func(**args)
1065
 
1066
 
1067
- async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
1068
- """Generic cache handling function"""
1069
- if hashing_kv is None:
1070
- return None, None, None, None
1071
-
1072
- # Get embedding cache configuration
1073
- embedding_cache_config = hashing_kv.global_config.get(
1074
- "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
1075
- )
1076
- is_embedding_cache_enabled = embedding_cache_config["enabled"]
1077
-
1078
- quantized = min_val = max_val = None
1079
- if is_embedding_cache_enabled:
1080
- # Use embedding cache
1081
- embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
1082
- current_embedding = await embedding_model_func([prompt])
1083
- quantized, min_val, max_val = quantize_embedding(current_embedding[0])
1084
- best_cached_response = await get_best_cached_response(
1085
- hashing_kv,
1086
- current_embedding[0],
1087
- similarity_threshold=embedding_cache_config["similarity_threshold"],
1088
- mode=mode,
1089
- )
1090
- if best_cached_response is not None:
1091
- return best_cached_response, None, None, None
1092
- else:
1093
- # Use regular cache
1094
- mode_cache = await hashing_kv.get_by_id(mode) or {}
1095
- if args_hash in mode_cache:
1096
- return mode_cache[args_hash]["return"], None, None, None
1097
-
1098
- return None, quantized, min_val, max_val
1099
-
1100
-
1101
- @dataclass
1102
- class CacheData:
1103
- args_hash: str
1104
- content: str
1105
- model: str
1106
- prompt: str
1107
- quantized: Optional[np.ndarray] = None
1108
- min_val: Optional[float] = None
1109
- max_val: Optional[float] = None
1110
- mode: str = "default"
1111
-
1112
-
1113
- async def save_to_cache(hashing_kv, cache_data: CacheData):
1114
- if hashing_kv is None:
1115
- return
1116
-
1117
- mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
1118
-
1119
- mode_cache[cache_data.args_hash] = {
1120
- "return": cache_data.content,
1121
- "model": cache_data.model,
1122
- "embedding": (
1123
- cache_data.quantized.tobytes().hex()
1124
- if cache_data.quantized is not None
1125
- else None
1126
- ),
1127
- "embedding_shape": (
1128
- cache_data.quantized.shape if cache_data.quantized is not None else None
1129
- ),
1130
- "embedding_min": cache_data.min_val,
1131
- "embedding_max": cache_data.max_val,
1132
- "original_prompt": cache_data.prompt,
1133
- }
1134
-
1135
- await hashing_kv.upsert({cache_data.mode: mode_cache})
1136
-
1137
-
1138
  if __name__ == "__main__":
1139
  import asyncio
1140
 
 
4
  import os
5
  import struct
6
  from functools import lru_cache
7
+ from typing import List, Dict, Callable, Any, Union
 
8
  import aioboto3
9
  import aiohttp
10
  import numpy as np
 
26
  )
27
  from transformers import AutoTokenizer, AutoModelForCausalLM
28
 
 
29
  from .utils import (
 
30
  wrap_embedding_func_with_attrs,
31
  locate_json_string_body_from_string,
 
 
32
  )
33
 
34
  import sys
 
61
  openai_async_client = (
62
  AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
63
  )
64
+ kwargs.pop("hashing_kv", None)
65
  messages = []
66
  if system_prompt:
67
  messages.append({"role": "system", "content": system_prompt})
68
  messages.extend(history_messages)
69
  messages.append({"role": "user", "content": prompt})
70
 
71
  if "response_format" in kwargs:
72
  response = await openai_async_client.beta.chat.completions.parse(
73
  model=model, messages=messages, **kwargs
 
93
  content = response.choices[0].message.content
94
  if r"\u" in content:
95
  content = content.encode("utf-8").decode("unicode_escape")
96
  return content
97
 
98
 
 
123
  api_key=os.getenv("AZURE_OPENAI_API_KEY"),
124
  api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
125
  )
126
+ kwargs.pop("hashing_kv", None)
 
 
 
127
  messages = []
128
  if system_prompt:
129
  messages.append({"role": "system", "content": system_prompt})
 
131
  if prompt is not None:
132
  messages.append({"role": "user", "content": prompt})
133
 
134
  response = await openai_async_client.chat.completions.create(
135
  model=model, messages=messages, **kwargs
136
  )
137
  content = response.choices[0].message.content
138
 
139
  return content
140
 
141
 
 
167
  os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
168
  "AWS_SESSION_TOKEN", aws_session_token
169
  )
170
+ kwargs.pop("hashing_kv", None)
171
  # Fix message history format
172
  messages = []
173
  for history_message in history_messages:
 
177
 
178
  # Add user prompt
179
  messages.append({"role": "user", "content": [{"text": prompt}]})
180
 
181
  # Initialize Converse API arguments
182
  args = {"modelId": model, "messages": messages}
 
199
  args["inferenceConfig"][inference_params_map.get(param, param)] = (
200
  kwargs.pop(param)
201
  )
202
 
203
  # Call model via Converse API
204
  session = aioboto3.Session()
 
208
  except Exception as e:
209
  raise BedrockError(e)
210
 
211
  return response["output"]["message"]["content"][0]["text"]
212
 
213
 
 
239
  ) -> str:
240
  model_name = model
241
  hf_model, hf_tokenizer = initialize_hf_model(model_name)
 
242
  messages = []
243
  if system_prompt:
244
  messages.append({"role": "system", "content": system_prompt})
245
  messages.extend(history_messages)
246
  messages.append({"role": "user", "content": prompt})
247
+ kwargs.pop("hashing_kv", None)
248
  input_prompt = ""
249
  try:
250
  input_prompt = hf_tokenizer.apply_chat_template(
 
289
  output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
290
  )
291
 
292
  return response_text
293
 
294
 
 
309
  # kwargs.pop("response_format", None) # allow json
310
  host = kwargs.pop("host", None)
311
  timeout = kwargs.pop("timeout", None)
312
+ kwargs.pop("hashing_kv", None)
313
  ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
314
  messages = []
315
  if system_prompt:
316
  messages.append({"role": "system", "content": system_prompt})
 
 
317
  messages.extend(history_messages)
318
  messages.append({"role": "user", "content": prompt})
319
 
320
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
321
  if stream:
322
  """cannot cache stream response"""
 
327
 
328
  return inner()
329
  else:
330
+ return response["message"]["content"]
331
 
332
 
333
  @lru_cache(maxsize=1)
 
402
  from lmdeploy import version_info, GenerationConfig
403
  except Exception:
404
  raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
405
+ kwargs.pop("hashing_kv", None)
406
  kwargs.pop("response_format", None)
407
  max_new_tokens = kwargs.pop("max_tokens", 512)
408
  tp = kwargs.pop("tp", 1)
 
434
  if system_prompt:
435
  messages.append({"role": "system", "content": system_prompt})
436
 
 
437
  messages.extend(history_messages)
438
  messages.append({"role": "user", "content": prompt})
439
 
440
  gen_config = GenerationConfig(
441
  skip_special_tokens=skip_special_tokens,
442
  max_new_tokens=max_new_tokens,
 
452
  session_id=1,
453
  ):
454
  response += res.response
455
  return response
456
 
457
 
 
612
  return np.array([dp.embedding for dp in response.data])
613
 
614
 
615
+ async def fetch_data(url, headers, data):
616
+ async with aiohttp.ClientSession() as session:
617
+ async with session.post(url, headers=headers, json=data) as response:
618
+ response_json = await response.json()
619
+ data_list = response_json.get("data", [])
620
+ return data_list
621
+
622
+
623
+ async def jina_embedding(
624
+ texts: list[str],
625
+ dimensions: int = 1024,
626
+ late_chunking: bool = False,
627
+ base_url: str = None,
628
+ api_key: str = None,
629
+ ) -> np.ndarray:
630
+ if api_key:
631
+ os.environ["JINA_API_KEY"] = api_key
632
+ url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
633
+ headers = {
634
+ "Content-Type": "application/json",
635
+ "Authorization": f"""Bearer {os.environ["JINA_API_KEY"]}""",
636
+ }
637
+ data = {
638
+ "model": "jina-embeddings-v3",
639
+ "normalized": True,
640
+ "embedding_type": "float",
641
+ "dimensions": f"{dimensions}",
642
+ "late_chunking": late_chunking,
643
+ "input": texts,
644
+ }
645
+ data_list = await fetch_data(url, headers, data)
646
+ return np.array([dp["embedding"] for dp in data_list])
647
+
648
+
649
  @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
650
  @retry(
651
  stop=stop_after_attempt(3),
 
931
  return await next_model.gen_func(**args)
932
 
933
 
934
  if __name__ == "__main__":
935
  import asyncio
936
 
lightrag/operate.py CHANGED
@@ -17,6 +17,10 @@ from .utils import (
17
  split_string_by_multi_markers,
18
  truncate_list_by_token_size,
19
  process_combine_contexts,
 
 
 
 
20
  )
21
  from .base import (
22
  BaseGraphStorage,
@@ -452,8 +456,17 @@ async def kg_query(
452
  text_chunks_db: BaseKVStorage[TextChunkSchema],
453
  query_param: QueryParam,
454
  global_config: dict,
 
455
  ) -> str:
456
- context = None
457
  example_number = global_config["addon_params"].get("example_number", None)
458
  if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
459
  examples = "\n".join(
@@ -471,12 +484,9 @@ async def kg_query(
471
  return PROMPTS["fail_response"]
472
 
473
  # LLM generate keywords
474
- use_model_func = global_config["llm_model_func"]
475
  kw_prompt_temp = PROMPTS["keywords_extraction"]
476
  kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
477
- result = await use_model_func(
478
- kw_prompt, keyword_extraction=True, mode=query_param.mode
479
- )
480
  logger.info("kw_prompt result:")
481
  print(result)
482
  try:
@@ -537,7 +547,6 @@ async def kg_query(
537
  query,
538
  system_prompt=sys_prompt,
539
  stream=query_param.stream,
540
- mode=query_param.mode,
541
  )
542
  if isinstance(response, str) and len(response) > len(sys_prompt):
543
  response = (
@@ -550,6 +559,20 @@ async def kg_query(
550
  .strip()
551
  )
552
 
553
  return response
554
 
555
 
@@ -1013,8 +1036,17 @@ async def naive_query(
1013
  text_chunks_db: BaseKVStorage[TextChunkSchema],
1014
  query_param: QueryParam,
1015
  global_config: dict,
 
1016
  ):
 
1017
  use_model_func = global_config["llm_model_func"]
 
  results = await chunks_vdb.query(query, top_k=query_param.top_k)
1019
  if not len(results):
1020
  return PROMPTS["fail_response"]
@@ -1039,7 +1071,6 @@ async def naive_query(
1039
  response = await use_model_func(
1040
  query,
1041
  system_prompt=sys_prompt,
1042
- mode=query_param.mode,
1043
  )
1044
 
1045
  if len(response) > len(sys_prompt):
@@ -1054,4 +1085,18 @@ async def naive_query(
1054
  .strip()
1055
  )
1056
 
1057
  return response
 
17
  split_string_by_multi_markers,
18
  truncate_list_by_token_size,
19
  process_combine_contexts,
20
+ compute_args_hash,
21
+ handle_cache,
22
+ save_to_cache,
23
+ CacheData,
24
  )
25
  from .base import (
26
  BaseGraphStorage,
 
456
  text_chunks_db: BaseKVStorage[TextChunkSchema],
457
  query_param: QueryParam,
458
  global_config: dict,
459
+ hashing_kv: BaseKVStorage = None,
460
  ) -> str:
461
+ # Handle cache
462
+ use_model_func = global_config["llm_model_func"]
463
+ args_hash = compute_args_hash(query_param.mode, query)
464
+ cached_response, quantized, min_val, max_val = await handle_cache(
465
+ hashing_kv, args_hash, query, query_param.mode
466
+ )
467
+ if cached_response is not None:
468
+ return cached_response
469
+
470
  example_number = global_config["addon_params"].get("example_number", None)
471
  if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
472
  examples = "\n".join(
 
484
  return PROMPTS["fail_response"]
485
 
486
  # LLM generate keywords
 
487
  kw_prompt_temp = PROMPTS["keywords_extraction"]
488
  kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
489
+ result = await use_model_func(kw_prompt, keyword_extraction=True)
 
 
490
  logger.info("kw_prompt result:")
491
  print(result)
492
  try:
 
547
  query,
548
  system_prompt=sys_prompt,
549
  stream=query_param.stream,
 
550
  )
551
  if isinstance(response, str) and len(response) > len(sys_prompt):
552
  response = (
 
559
  .strip()
560
  )
561
 
562
+ # Save to cache
563
+ await save_to_cache(
564
+ hashing_kv,
565
+ CacheData(
566
+ args_hash=args_hash,
567
+ content=response,
568
+ prompt=query,
569
+ quantized=quantized,
570
+ min_val=min_val,
571
+ max_val=max_val,
572
+ mode=query_param.mode,
573
+ ),
574
+ )
575
+
576
  return response
577
 
578
 
 
1036
  text_chunks_db: BaseKVStorage[TextChunkSchema],
1037
  query_param: QueryParam,
1038
  global_config: dict,
1039
+ hashing_kv: BaseKVStorage = None,
1040
  ):
1041
+ # Handle cache
1042
  use_model_func = global_config["llm_model_func"]
1043
+ args_hash = compute_args_hash(query_param.mode, query)
1044
+ cached_response, quantized, min_val, max_val = await handle_cache(
1045
+ hashing_kv, args_hash, query, query_param.mode
1046
+ )
1047
+ if cached_response is not None:
1048
+ return cached_response
1049
+
1050
  results = await chunks_vdb.query(query, top_k=query_param.top_k)
1051
  if not len(results):
1052
  return PROMPTS["fail_response"]
 
1071
  response = await use_model_func(
1072
  query,
1073
  system_prompt=sys_prompt,
 
1074
  )
1075
 
1076
  if len(response) > len(sys_prompt):
 
1085
  .strip()
1086
  )
1087
 
1088
+ # Save to cache
1089
+ await save_to_cache(
1090
+ hashing_kv,
1091
+ CacheData(
1092
+ args_hash=args_hash,
1093
+ content=response,
1094
+ prompt=query,
1095
+ quantized=quantized,
1096
+ min_val=min_val,
1097
+ max_val=max_val,
1098
+ mode=query_param.mode,
1099
+ ),
1100
+ )
1101
+
1102
  return response
lightrag/prompt.py CHANGED
@@ -261,3 +261,22 @@ Do not include information where the supporting evidence for it is not provided.
261
 
262
  Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
263
  """
 
261
 
262
  Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
263
  """
264
+
265
+ PROMPTS[
266
+ "similarity_check"
267
+ ] = """Please analyze the similarity between these two questions:
268
+
269
+ Question 1: {original_prompt}
270
+ Question 2: {cached_prompt}
271
+
272
+ Please evaluate:
273
+ 1. Whether these two questions are semantically similar
274
+ 2. Whether the answer to Question 2 can be used to answer Question 1
275
+
276
+ Please provide a similarity score between 0 and 1, where:
277
+ 0: Completely unrelated or answer cannot be reused
278
+ 1: Identical and answer can be directly reused
279
+ 0.5: Partially related and answer needs modification to be used
280
+
281
+ Return only a number between 0-1, without any additional content.
282
+ """
lightrag/utils.py CHANGED
@@ -9,12 +9,14 @@ import re
9
  from dataclasses import dataclass
10
  from functools import wraps
11
  from hashlib import md5
12
- from typing import Any, Union, List
13
  import xml.etree.ElementTree as ET
14
 
15
  import numpy as np
16
  import tiktoken
17
 
 
 
18
  ENCODER = None
19
 
20
  logger = logging.getLogger("lightrag")
@@ -314,6 +316,9 @@ async def get_best_cached_response(
314
  current_embedding,
315
  similarity_threshold=0.95,
316
  mode="default",
 
 
 
317
  ) -> Union[str, None]:
318
  # Get mode-specific cache
319
  mode_cache = await hashing_kv.get_by_id(mode)
@@ -348,6 +353,37 @@ async def get_best_cached_response(
348
  best_cache_id = cache_id
349
 
350
  if best_similarity > similarity_threshold:
351
  prompt_display = (
352
  best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
353
  )
@@ -390,3 +426,84 @@ def dequantize_embedding(
390
  """Restore quantized embedding"""
391
  scale = (max_val - min_val) / (2**bits - 1)
392
  return (quantized * scale + min_val).astype(np.float32)
9
  from dataclasses import dataclass
10
  from functools import wraps
11
  from hashlib import md5
12
+ from typing import Any, Union, List, Optional
13
  import xml.etree.ElementTree as ET
14
 
15
  import numpy as np
16
  import tiktoken
17
 
18
+ from lightrag.prompt import PROMPTS
19
+
20
  ENCODER = None
21
 
22
  logger = logging.getLogger("lightrag")
 
316
  current_embedding,
317
  similarity_threshold=0.95,
318
  mode="default",
319
+ use_llm_check=False,
320
+ llm_func=None,
321
+ original_prompt=None,
322
  ) -> Union[str, None]:
323
  # Get mode-specific cache
324
  mode_cache = await hashing_kv.get_by_id(mode)
 
353
  best_cache_id = cache_id
354
 
355
  if best_similarity > similarity_threshold:
356
+ # If LLM check is enabled and all required parameters are provided
357
+ if use_llm_check and llm_func and original_prompt and best_prompt:
358
+ compare_prompt = PROMPTS["similarity_check"].format(
359
+ original_prompt=original_prompt, cached_prompt=best_prompt
360
+ )
361
+
362
+ try:
363
+ llm_result = await llm_func(compare_prompt)
364
+ llm_result = llm_result.strip()
365
+ llm_similarity = float(llm_result)
366
+
367
+ # Replace vector similarity with LLM similarity score
368
+ best_similarity = llm_similarity
369
+ if best_similarity < similarity_threshold:
370
+ log_data = {
371
+ "event": "llm_check_cache_rejected",
372
+ "original_question": original_prompt[:100] + "..."
373
+ if len(original_prompt) > 100
374
+ else original_prompt,
375
+ "cached_question": best_prompt[:100] + "..."
376
+ if len(best_prompt) > 100
377
+ else best_prompt,
378
+ "similarity_score": round(best_similarity, 4),
379
+ "threshold": similarity_threshold,
380
+ }
381
+ logger.info(json.dumps(log_data, ensure_ascii=False))
382
+ return None
383
+ except Exception as e: # Catch all possible exceptions
384
+ logger.warning(f"LLM similarity check failed: {e}")
385
+ return None # Return None directly when LLM check fails
386
+
387
  prompt_display = (
388
  best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
389
  )
 
426
  """Restore quantized embedding"""
427
  scale = (max_val - min_val) / (2**bits - 1)
428
  return (quantized * scale + min_val).astype(np.float32)
429
+
430
+
431
+ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
432
+ """Generic cache handling function"""
433
+ if hashing_kv is None:
434
+ return None, None, None, None
435
+
436
+ # For naive mode, only use simple cache matching
437
+ if mode == "naive":
438
+ mode_cache = await hashing_kv.get_by_id(mode) or {}
439
+ if args_hash in mode_cache:
440
+ return mode_cache[args_hash]["return"], None, None, None
441
+ return None, None, None, None
442
+
443
+ # Get embedding cache configuration
444
+ embedding_cache_config = hashing_kv.global_config.get(
445
+ "embedding_cache_config",
446
+ {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
447
+ )
448
+ is_embedding_cache_enabled = embedding_cache_config["enabled"]
449
+ use_llm_check = embedding_cache_config.get("use_llm_check", False)
450
+
451
+ quantized = min_val = max_val = None
452
+ if is_embedding_cache_enabled:
453
+ # Use embedding cache
454
+ embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
455
+ llm_model_func = hashing_kv.global_config.get("llm_model_func")
456
+
457
+ current_embedding = await embedding_model_func([prompt])
458
+ quantized, min_val, max_val = quantize_embedding(current_embedding[0])
459
+ best_cached_response = await get_best_cached_response(
460
+ hashing_kv,
461
+ current_embedding[0],
462
+ similarity_threshold=embedding_cache_config["similarity_threshold"],
463
+ mode=mode,
464
+ use_llm_check=use_llm_check,
465
+ llm_func=llm_model_func if use_llm_check else None,
466
+ original_prompt=prompt if use_llm_check else None,
467
+ )
468
+ if best_cached_response is not None:
469
+ return best_cached_response, None, None, None
470
+ else:
471
+ # Use regular cache
472
+ mode_cache = await hashing_kv.get_by_id(mode) or {}
473
+ if args_hash in mode_cache:
474
+ return mode_cache[args_hash]["return"], None, None, None
475
+
476
+ return None, quantized, min_val, max_val
477
+
478
+
479
+ @dataclass
480
+ class CacheData:
481
+ args_hash: str
482
+ content: str
483
+ prompt: str
484
+ quantized: Optional[np.ndarray] = None
485
+ min_val: Optional[float] = None
486
+ max_val: Optional[float] = None
487
+ mode: str = "default"
488
+
489
+
490
+ async def save_to_cache(hashing_kv, cache_data: CacheData):
491
+ if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
492
+ return
493
+
494
+ mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
495
+
496
+ mode_cache[cache_data.args_hash] = {
497
+ "return": cache_data.content,
498
+ "embedding": cache_data.quantized.tobytes().hex()
499
+ if cache_data.quantized is not None
500
+ else None,
501
+ "embedding_shape": cache_data.quantized.shape
502
+ if cache_data.quantized is not None
503
+ else None,
504
+ "embedding_min": cache_data.min_val,
505
+ "embedding_max": cache_data.max_val,
506
+ "original_prompt": cache_data.prompt,
507
+ }
508
+
509
+ await hashing_kv.upsert({cache_data.mode: mode_cache})
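
Pulling the pieces together, here is a condensed, hypothetical wrapper (`answer_with_cache` is not a function in the repo) that mirrors what `kg_query` and `naive_query` now do with these helpers: try the cache first, and on a miss call the model and persist the answer together with the quantized question embedding.

```python
from lightrag.utils import CacheData, compute_args_hash, handle_cache, save_to_cache


async def answer_with_cache(query, query_param, hashing_kv, use_model_func, sys_prompt=None):
    # Cheap lookup first: an exact-hash match, or (when embedding_cache_config
    # enables it) an embedding-similarity match optionally re-verified by the LLM.
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    # Cache miss: call the model, then store the answer plus the quantized
    # question embedding for future similarity lookups.
    response = await use_model_func(query, system_prompt=sys_prompt)
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
        ),
    )
    return response
```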