Magicyuan committed
Commit 4e1a2a0 · 1 Parent(s): 18ef39c

Fix: args_hash was computed only on the regular-cache path, so it was never computed when the embedding cache was in use
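In words: each `*_complete_if_cache` function now computes `args_hash` as soon as it enters the `if hashing_kv is not None:` block, before branching on the embedding-cache configuration, so the hash exists for both the exact-match cache and the similarity-based embedding cache. A minimal sketch of that control flow (illustrative names, not the repository's exact code, apart from `compute_args_hash` and the `embedding_cache_config` lookup visible in the diff below):

    from hashlib import md5

    def compute_args_hash(*args) -> str:
        # Stable hash over the call arguments, mirroring
        # lightrag.utils.compute_args_hash.
        return md5(str(args).encode()).hexdigest()

    def resolve_cache(hashing_kv, model: str, messages: list):
        """Sketch: the hash is computed whenever a cache store exists,
        not only on the regular-cache branch, so the embedding cache can
        also store fresh results under it."""
        if hashing_kv is None:
            return None, None
        # The fix: compute the hash up front, unconditionally, once caching
        # is in play; previously it was skipped on the embedding-cache path.
        args_hash = compute_args_hash(model, messages)
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config",
            {"enabled": False, "similarity_threshold": 0.95},
        )
        return args_hash, embedding_cache_config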

Files changed (1)
lightrag/llm.py  +17 -17
lightrag/llm.py CHANGED
@@ -1,12 +1,16 @@
-import os
+import base64
 import copy
-from functools import lru_cache
 import json
+import os
+import struct
+from functools import lru_cache
+from typing import List, Dict, Callable, Any
+
 import aioboto3
 import aiohttp
 import numpy as np
 import ollama
-
+import torch
 from openai import (
     AsyncOpenAI,
     APIConnectionError,
@@ -14,10 +18,7 @@ from openai import (
     Timeout,
     AsyncAzureOpenAI,
 )
-
-import base64
-import struct
-
+from pydantic import BaseModel, Field
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -25,9 +26,7 @@ from tenacity import (
     retry_if_exception_type,
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any
+
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -70,7 +69,7 @@ async def openai_complete_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -167,7 +166,7 @@ async def azure_openai_complete_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -281,7 +280,7 @@ async def bedrock_complete_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -378,7 +377,7 @@ async def hf_model_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -496,7 +495,7 @@ async def ollama_model_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -657,7 +656,7 @@ async def lmdeploy_model_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -867,7 +866,8 @@ async def openai_embedding(
 )
 async def nvidia_openai_embedding(
     texts: list[str],
-    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",  # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
+    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
+    # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
     base_url: str = "https://integrate.api.nvidia.com/v1",
     api_key: str = None,
     input_type: str = "passage",  # query for retrieval, passage for embedding
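For completeness, the embedding-cache configuration read in every hunk above comes from the global config; a hedged usage sketch, assuming the `embedding_cache_config` constructor argument is wired through to `global_config` as the lookup in the diff suggests:

    from lightrag import LightRAG

    # Hypothetical setup: enable the similarity-based embedding cache whose
    # miss path this commit fixes. The two keys match the defaults read in
    # lightrag/llm.py via hashing_kv.global_config.get("embedding_cache_config", ...).
    rag = LightRAG(
        working_dir="./rag_storage",
        embedding_cache_config={
            "enabled": True,               # turn on embedding-similarity lookups
            "similarity_threshold": 0.95,  # reuse a cached answer at/above this score
        },
    )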