Magicyuan committed on
Commit
abffd90
·
1 Parent(s): 730e728

Migrate cache computation functions to the utility module

Files changed (1)
  1. lightrag/utils.py +69 -1
lightrag/utils.py CHANGED
@@ -9,7 +9,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union, List
+from typing import Any, Union, List, Optional
 import xml.etree.ElementTree as ET
 
 import numpy as np
@@ -390,3 +390,71 @@ def dequantize_embedding(
     """Restore quantized embedding"""
     scale = (max_val - min_val) / (2**bits - 1)
     return (quantized * scale + min_val).astype(np.float32)
+
+
+async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
+    """Generic cache handling function"""
+    if hashing_kv is None:
+        return None, None, None, None
+
+    # Get embedding cache configuration
+    embedding_cache_config = hashing_kv.global_config.get(
+        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+    )
+    is_embedding_cache_enabled = embedding_cache_config["enabled"]
+
+    quantized = min_val = max_val = None
+    if is_embedding_cache_enabled:
+        # Use embedding cache
+        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        current_embedding = await embedding_model_func([prompt])
+        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+        best_cached_response = await get_best_cached_response(
+            hashing_kv,
+            current_embedding[0],
+            similarity_threshold=embedding_cache_config["similarity_threshold"],
+            mode=mode,
+        )
+        if best_cached_response is not None:
+            return best_cached_response, None, None, None
+    else:
+        # Use regular cache
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if args_hash in mode_cache:
+            return mode_cache[args_hash]["return"], None, None, None
+
+    return None, quantized, min_val, max_val
+
+
+@dataclass
+class CacheData:
+    args_hash: str
+    content: str
+    model: str
+    prompt: str
+    quantized: Optional[np.ndarray] = None
+    min_val: Optional[float] = None
+    max_val: Optional[float] = None
+    mode: str = "default"
+
+
+async def save_to_cache(hashing_kv, cache_data: CacheData):
+    if hashing_kv is None:
+        return
+
+    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+
+    mode_cache[cache_data.args_hash] = {
+        "return": cache_data.content,
+        "model": cache_data.model,
+        "embedding": cache_data.quantized.tobytes().hex()
+        if cache_data.quantized is not None
+        else None,
+        "embedding_shape": cache_data.quantized.shape
+        if cache_data.quantized is not None
+        else None,
+        "embedding_min": cache_data.min_val,
+        "embedding_max": cache_data.max_val,
+        "original_prompt": cache_data.prompt,
+    }
+
+    await hashing_kv.upsert({cache_data.mode: mode_cache})
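For reviewers, a minimal sketch of how the two new helpers are meant to bracket an LLM call. This caller is not part of the commit: compute_args_hash is assumed from elsewhere in lightrag/utils.py, and llm_model_func plus the model name are placeholders.

async def llm_with_cache(hashing_kv, prompt, mode="default"):
    # Hypothetical caller; compute_args_hash and llm_model_func are assumed.
    args_hash = compute_args_hash(mode, prompt)

    # Check the cache first. On a hit handle_cache returns the cached
    # response; on a miss it returns the quantized prompt embedding so the
    # fresh result can be stored alongside it.
    cached, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode=mode
    )
    if cached is not None:
        return cached

    response = await llm_model_func(prompt)  # placeholder model call
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            model="example-model",  # placeholder model name
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return response

Design note: the quantized embedding is serialized with tobytes().hex() so the whole cache entry stays JSON-serializable in the KV store; get_best_cached_response presumably reverses this with bytes.fromhex and np.frombuffer before dequantizing.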