修复 args_hash 仅在使用常规缓存时才计算、导致使用 embedding 缓存时未计算的 bug
Browse files
- lightrag/llm.py +17 -17
lightrag/llm.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1 |
-
import
|
2 |
import copy
|
3 |
-
from functools import lru_cache
|
4 |
import json
|
|
|
|
|
|
|
|
|
|
|
5 |
import aioboto3
|
6 |
import aiohttp
|
7 |
import numpy as np
|
8 |
import ollama
|
9 |
-
|
10 |
from openai import (
|
11 |
AsyncOpenAI,
|
12 |
APIConnectionError,
|
@@ -14,10 +18,7 @@ from openai import (
|
|
14 |
Timeout,
|
15 |
AsyncAzureOpenAI,
|
16 |
)
|
17 |
-
|
18 |
-
import base64
|
19 |
-
import struct
|
20 |
-
|
21 |
from tenacity import (
|
22 |
retry,
|
23 |
stop_after_attempt,
|
@@ -25,9 +26,7 @@ from tenacity import (
|
|
25 |
retry_if_exception_type,
|
26 |
)
|
27 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
28 |
-
|
29 |
-
from pydantic import BaseModel, Field
|
30 |
-
from typing import List, Dict, Callable, Any
|
31 |
from .base import BaseKVStorage
|
32 |
from .utils import (
|
33 |
compute_args_hash,
|
@@ -70,7 +69,7 @@ async def openai_complete_if_cache(
|
|
70 |
if hashing_kv is not None:
|
71 |
# Calculate args_hash only when using cache
|
72 |
args_hash = compute_args_hash(model, messages)
|
73 |
-
|
74 |
# Get embedding cache configuration
|
75 |
embedding_cache_config = hashing_kv.global_config.get(
|
76 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
@@ -167,7 +166,7 @@ async def azure_openai_complete_if_cache(
|
|
167 |
if hashing_kv is not None:
|
168 |
# Calculate args_hash only when using cache
|
169 |
args_hash = compute_args_hash(model, messages)
|
170 |
-
|
171 |
# Get embedding cache configuration
|
172 |
embedding_cache_config = hashing_kv.global_config.get(
|
173 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
@@ -281,7 +280,7 @@ async def bedrock_complete_if_cache(
|
|
281 |
if hashing_kv is not None:
|
282 |
# Calculate args_hash only when using cache
|
283 |
args_hash = compute_args_hash(model, messages)
|
284 |
-
|
285 |
# Get embedding cache configuration
|
286 |
embedding_cache_config = hashing_kv.global_config.get(
|
287 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
@@ -378,7 +377,7 @@ async def hf_model_if_cache(
|
|
378 |
if hashing_kv is not None:
|
379 |
# Calculate args_hash only when using cache
|
380 |
args_hash = compute_args_hash(model, messages)
|
381 |
-
|
382 |
# Get embedding cache configuration
|
383 |
embedding_cache_config = hashing_kv.global_config.get(
|
384 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
@@ -496,7 +495,7 @@ async def ollama_model_if_cache(
|
|
496 |
if hashing_kv is not None:
|
497 |
# Calculate args_hash only when using cache
|
498 |
args_hash = compute_args_hash(model, messages)
|
499 |
-
|
500 |
# Get embedding cache configuration
|
501 |
embedding_cache_config = hashing_kv.global_config.get(
|
502 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
@@ -657,7 +656,7 @@ async def lmdeploy_model_if_cache(
|
|
657 |
if hashing_kv is not None:
|
658 |
# Calculate args_hash only when using cache
|
659 |
args_hash = compute_args_hash(model, messages)
|
660 |
-
|
661 |
# Get embedding cache configuration
|
662 |
embedding_cache_config = hashing_kv.global_config.get(
|
663 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
@@ -867,7 +866,8 @@ async def openai_embedding(
|
|
867 |
)
|
868 |
async def nvidia_openai_embedding(
|
869 |
texts: list[str],
|
870 |
-
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
|
|
871 |
base_url: str = "https://integrate.api.nvidia.com/v1",
|
872 |
api_key: str = None,
|
873 |
input_type: str = "passage", # query for retrieval, passage for embedding
|
|
|
1 |
+
import base64
|
2 |
import copy
|
|
|
3 |
import json
|
4 |
+
import os
|
5 |
+
import struct
|
6 |
+
from functools import lru_cache
|
7 |
+
from typing import List, Dict, Callable, Any
|
8 |
+
|
9 |
import aioboto3
|
10 |
import aiohttp
|
11 |
import numpy as np
|
12 |
import ollama
|
13 |
+
import torch
|
14 |
from openai import (
|
15 |
AsyncOpenAI,
|
16 |
APIConnectionError,
|
|
|
18 |
Timeout,
|
19 |
AsyncAzureOpenAI,
|
20 |
)
|
21 |
+
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
22 |
from tenacity import (
|
23 |
retry,
|
24 |
stop_after_attempt,
|
|
|
26 |
retry_if_exception_type,
|
27 |
)
|
28 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
29 |
+
|
|
|
|
|
30 |
from .base import BaseKVStorage
|
31 |
from .utils import (
|
32 |
compute_args_hash,
|
|
|
69 |
if hashing_kv is not None:
|
70 |
# Calculate args_hash only when using cache
|
71 |
args_hash = compute_args_hash(model, messages)
|
72 |
+
|
73 |
# Get embedding cache configuration
|
74 |
embedding_cache_config = hashing_kv.global_config.get(
|
75 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
|
|
166 |
if hashing_kv is not None:
|
167 |
# Calculate args_hash only when using cache
|
168 |
args_hash = compute_args_hash(model, messages)
|
169 |
+
|
170 |
# Get embedding cache configuration
|
171 |
embedding_cache_config = hashing_kv.global_config.get(
|
172 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
|
|
280 |
if hashing_kv is not None:
|
281 |
# Calculate args_hash only when using cache
|
282 |
args_hash = compute_args_hash(model, messages)
|
283 |
+
|
284 |
# Get embedding cache configuration
|
285 |
embedding_cache_config = hashing_kv.global_config.get(
|
286 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
|
|
377 |
if hashing_kv is not None:
|
378 |
# Calculate args_hash only when using cache
|
379 |
args_hash = compute_args_hash(model, messages)
|
380 |
+
|
381 |
# Get embedding cache configuration
|
382 |
embedding_cache_config = hashing_kv.global_config.get(
|
383 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
|
|
495 |
if hashing_kv is not None:
|
496 |
# Calculate args_hash only when using cache
|
497 |
args_hash = compute_args_hash(model, messages)
|
498 |
+
|
499 |
# Get embedding cache configuration
|
500 |
embedding_cache_config = hashing_kv.global_config.get(
|
501 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
|
|
656 |
if hashing_kv is not None:
|
657 |
# Calculate args_hash only when using cache
|
658 |
args_hash = compute_args_hash(model, messages)
|
659 |
+
|
660 |
# Get embedding cache configuration
|
661 |
embedding_cache_config = hashing_kv.global_config.get(
|
662 |
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
|
|
866 |
)
|
867 |
async def nvidia_openai_embedding(
|
868 |
texts: list[str],
|
869 |
+
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
870 |
+
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
|
871 |
base_url: str = "https://integrate.api.nvidia.com/v1",
|
872 |
api_key: str = None,
|
873 |
input_type: str = "passage", # query for retrieval, passage for embedding
|