omdivyatej
commited on
Commit
·
cb56de7
1
Parent(s):
1f844c6
linting errors
Browse files- examples/lightrag_multi_model_all_modes_demo.py +11 -13
- lightrag/base.py +1 -1
- lightrag/lightrag.py +1 -1
- lightrag/operate.py +23 -5
examples/lightrag_multi_model_all_modes_demo.py
CHANGED
|
@@ -9,6 +9,7 @@ WORKING_DIR = "./lightrag_demo"
|
|
| 9 |
if not os.path.exists(WORKING_DIR):
|
| 10 |
os.mkdir(WORKING_DIR)
|
| 11 |
|
|
|
|
| 12 |
async def initialize_rag():
|
| 13 |
rag = LightRAG(
|
| 14 |
working_dir=WORKING_DIR,
|
|
@@ -21,6 +22,7 @@ async def initialize_rag():
|
|
| 21 |
|
| 22 |
return rag
|
| 23 |
|
|
|
|
| 24 |
def main():
|
| 25 |
# Initialize RAG instance
|
| 26 |
rag = asyncio.run(initialize_rag())
|
|
@@ -33,8 +35,7 @@ def main():
|
|
| 33 |
print("--- NAIVE mode ---")
|
| 34 |
print(
|
| 35 |
rag.query(
|
| 36 |
-
"What are the main themes in this story?",
|
| 37 |
-
param=QueryParam(mode="naive")
|
| 38 |
)
|
| 39 |
)
|
| 40 |
|
|
@@ -42,8 +43,7 @@ def main():
|
|
| 42 |
print("\n--- LOCAL mode ---")
|
| 43 |
print(
|
| 44 |
rag.query(
|
| 45 |
-
"What are the main themes in this story?",
|
| 46 |
-
param=QueryParam(mode="local")
|
| 47 |
)
|
| 48 |
)
|
| 49 |
|
|
@@ -51,8 +51,7 @@ def main():
|
|
| 51 |
print("\n--- GLOBAL mode ---")
|
| 52 |
print(
|
| 53 |
rag.query(
|
| 54 |
-
"What are the main themes in this story?",
|
| 55 |
-
param=QueryParam(mode="global")
|
| 56 |
)
|
| 57 |
)
|
| 58 |
|
|
@@ -60,8 +59,7 @@ def main():
|
|
| 60 |
print("\n--- HYBRID mode ---")
|
| 61 |
print(
|
| 62 |
rag.query(
|
| 63 |
-
"What are the main themes in this story?",
|
| 64 |
-
param=QueryParam(mode="hybrid")
|
| 65 |
)
|
| 66 |
)
|
| 67 |
|
|
@@ -69,8 +67,7 @@ def main():
|
|
| 69 |
print("\n--- MIX mode ---")
|
| 70 |
print(
|
| 71 |
rag.query(
|
| 72 |
-
"What are the main themes in this story?",
|
| 73 |
-
param=QueryParam(mode="mix")
|
| 74 |
)
|
| 75 |
)
|
| 76 |
|
|
@@ -81,10 +78,11 @@ def main():
|
|
| 81 |
"How does the character development reflect Victorian-era attitudes?",
|
| 82 |
param=QueryParam(
|
| 83 |
mode="global",
|
| 84 |
-
model_func=gpt_4o_complete # Override default model with more capable one
|
| 85 |
-
)
|
| 86 |
)
|
| 87 |
)
|
| 88 |
|
|
|
|
| 89 |
if __name__ == "__main__":
|
| 90 |
-
main()
|
|
|
|
| 9 |
if not os.path.exists(WORKING_DIR):
|
| 10 |
os.mkdir(WORKING_DIR)
|
| 11 |
|
| 12 |
+
|
| 13 |
async def initialize_rag():
|
| 14 |
rag = LightRAG(
|
| 15 |
working_dir=WORKING_DIR,
|
|
|
|
| 22 |
|
| 23 |
return rag
|
| 24 |
|
| 25 |
+
|
| 26 |
def main():
|
| 27 |
# Initialize RAG instance
|
| 28 |
rag = asyncio.run(initialize_rag())
|
|
|
|
| 35 |
print("--- NAIVE mode ---")
|
| 36 |
print(
|
| 37 |
rag.query(
|
| 38 |
+
"What are the main themes in this story?", param=QueryParam(mode="naive")
|
|
|
|
| 39 |
)
|
| 40 |
)
|
| 41 |
|
|
|
|
| 43 |
print("\n--- LOCAL mode ---")
|
| 44 |
print(
|
| 45 |
rag.query(
|
| 46 |
+
"What are the main themes in this story?", param=QueryParam(mode="local")
|
|
|
|
| 47 |
)
|
| 48 |
)
|
| 49 |
|
|
|
|
| 51 |
print("\n--- GLOBAL mode ---")
|
| 52 |
print(
|
| 53 |
rag.query(
|
| 54 |
+
"What are the main themes in this story?", param=QueryParam(mode="global")
|
|
|
|
| 55 |
)
|
| 56 |
)
|
| 57 |
|
|
|
|
| 59 |
print("\n--- HYBRID mode ---")
|
| 60 |
print(
|
| 61 |
rag.query(
|
| 62 |
+
"What are the main themes in this story?", param=QueryParam(mode="hybrid")
|
|
|
|
| 63 |
)
|
| 64 |
)
|
| 65 |
|
|
|
|
| 67 |
print("\n--- MIX mode ---")
|
| 68 |
print(
|
| 69 |
rag.query(
|
| 70 |
+
"What are the main themes in this story?", param=QueryParam(mode="mix")
|
|
|
|
| 71 |
)
|
| 72 |
)
|
| 73 |
|
|
|
|
| 78 |
"How does the character development reflect Victorian-era attitudes?",
|
| 79 |
param=QueryParam(
|
| 80 |
mode="global",
|
| 81 |
+
model_func=gpt_4o_complete, # Override default model with more capable one
|
| 82 |
+
),
|
| 83 |
)
|
| 84 |
)
|
| 85 |
|
| 86 |
+
|
| 87 |
if __name__ == "__main__":
|
| 88 |
+
main()
|
lightrag/base.py
CHANGED
|
@@ -84,7 +84,7 @@ class QueryParam:
|
|
| 84 |
|
| 85 |
ids: list[str] | None = None
|
| 86 |
"""List of ids to filter the results."""
|
| 87 |
-
|
| 88 |
model_func: Callable[..., object] | None = None
|
| 89 |
"""Optional override for the LLM model function to use for this specific query.
|
| 90 |
If provided, this will be used instead of the global model function.
|
|
|
|
| 84 |
|
| 85 |
ids: list[str] | None = None
|
| 86 |
"""List of ids to filter the results."""
|
| 87 |
+
|
| 88 |
model_func: Callable[..., object] | None = None
|
| 89 |
"""Optional override for the LLM model function to use for this specific query.
|
| 90 |
If provided, this will be used instead of the global model function.
|
lightrag/lightrag.py
CHANGED
|
@@ -1338,7 +1338,7 @@ class LightRAG:
|
|
| 1338 |
"""
|
| 1339 |
# If a custom model is provided in param, temporarily update global config
|
| 1340 |
global_config = asdict(self)
|
| 1341 |
-
|
| 1342 |
if param.mode in ["local", "global", "hybrid"]:
|
| 1343 |
response = await kg_query(
|
| 1344 |
query.strip(),
|
|
|
|
| 1338 |
"""
|
| 1339 |
# If a custom model is provided in param, temporarily update global config
|
| 1340 |
global_config = asdict(self)
|
| 1341 |
+
|
| 1342 |
if param.mode in ["local", "global", "hybrid"]:
|
| 1343 |
response = await kg_query(
|
| 1344 |
query.strip(),
|
lightrag/operate.py
CHANGED
|
@@ -705,7 +705,11 @@ async def kg_query(
|
|
| 705 |
system_prompt: str | None = None,
|
| 706 |
) -> str | AsyncIterator[str]:
|
| 707 |
# Handle cache
|
| 708 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
| 710 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 711 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
@@ -866,7 +870,9 @@ async def extract_keywords_only(
|
|
| 866 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
| 867 |
|
| 868 |
# 5. Call the LLM for keyword extraction
|
| 869 |
-
use_model_func =
|
|
|
|
|
|
|
| 870 |
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
| 871 |
|
| 872 |
# 6. Parse out JSON from the LLM response
|
|
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
|
|
| 926 |
3. Combining both results for comprehensive answer generation
|
| 927 |
"""
|
| 928 |
# 1. Cache handling
|
| 929 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 930 |
args_hash = compute_args_hash("mix", query, cache_type="query")
|
| 931 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 932 |
hashing_kv, args_hash, query, "mix", cache_type="query"
|
|
@@ -1731,7 +1741,11 @@ async def naive_query(
|
|
| 1731 |
system_prompt: str | None = None,
|
| 1732 |
) -> str | AsyncIterator[str]:
|
| 1733 |
# Handle cache
|
| 1734 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1735 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
| 1736 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 1737 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
|
|
| 1850 |
# ---------------------------
|
| 1851 |
# 1) Handle potential cache for query results
|
| 1852 |
# ---------------------------
|
| 1853 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1854 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
| 1855 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 1856 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
|
|
| 705 |
system_prompt: str | None = None,
|
| 706 |
) -> str | AsyncIterator[str]:
|
| 707 |
# Handle cache
|
| 708 |
+
use_model_func = (
|
| 709 |
+
query_param.model_func
|
| 710 |
+
if query_param.model_func
|
| 711 |
+
else global_config["llm_model_func"]
|
| 712 |
+
)
|
| 713 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
| 714 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 715 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
|
|
| 870 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
| 871 |
|
| 872 |
# 5. Call the LLM for keyword extraction
|
| 873 |
+
use_model_func = (
|
| 874 |
+
param.model_func if param.model_func else global_config["llm_model_func"]
|
| 875 |
+
)
|
| 876 |
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
| 877 |
|
| 878 |
# 6. Parse out JSON from the LLM response
|
|
|
|
| 932 |
3. Combining both results for comprehensive answer generation
|
| 933 |
"""
|
| 934 |
# 1. Cache handling
|
| 935 |
+
use_model_func = (
|
| 936 |
+
query_param.model_func
|
| 937 |
+
if query_param.model_func
|
| 938 |
+
else global_config["llm_model_func"]
|
| 939 |
+
)
|
| 940 |
args_hash = compute_args_hash("mix", query, cache_type="query")
|
| 941 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 942 |
hashing_kv, args_hash, query, "mix", cache_type="query"
|
|
|
|
| 1741 |
system_prompt: str | None = None,
|
| 1742 |
) -> str | AsyncIterator[str]:
|
| 1743 |
# Handle cache
|
| 1744 |
+
use_model_func = (
|
| 1745 |
+
query_param.model_func
|
| 1746 |
+
if query_param.model_func
|
| 1747 |
+
else global_config["llm_model_func"]
|
| 1748 |
+
)
|
| 1749 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
| 1750 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 1751 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
|
|
| 1864 |
# ---------------------------
|
| 1865 |
# 1) Handle potential cache for query results
|
| 1866 |
# ---------------------------
|
| 1867 |
+
use_model_func = (
|
| 1868 |
+
query_param.model_func
|
| 1869 |
+
if query_param.model_func
|
| 1870 |
+
else global_config["llm_model_func"]
|
| 1871 |
+
)
|
| 1872 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
| 1873 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
| 1874 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|