import asyncio
import os

import nest_asyncio
import numpy as np

# Patch asyncio so event loops can be nested; kept ahead of the lightrag
# imports, as in the original ordering.
nest_asyncio.apply()

from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm import (
    openai_complete_if_cache,
    nvidia_openai_embed,
)
from lightrag.utils import EmbeddingFunc
from lightrag.utils import locate_json_string_body_from_string
						|  | WORKING_DIR = "./dickens" | 
					
						
						|  |  | 
					
						
						|  | if not os.path.exists(WORKING_DIR): | 
					
						
						|  | os.mkdir(WORKING_DIR) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | NVIDIA_OPENAI_API_KEY = "nvapi-xxxx" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def llm_model_func( | 
					
						
						|  | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | 
					
						
						|  | ) -> str: | 
					
						
						|  | result = await openai_complete_if_cache( | 
					
						
						|  | "nvidia/llama-3.1-nemotron-70b-instruct", | 
					
						
						|  | prompt, | 
					
						
						|  | system_prompt=system_prompt, | 
					
						
						|  | history_messages=history_messages, | 
					
						
						|  | api_key=NVIDIA_OPENAI_API_KEY, | 
					
						
						|  | base_url="https://integrate.api.nvidia.com/v1", | 
					
						
						|  | **kwargs, | 
					
						
						|  | ) | 
					
						
						|  | if keyword_extraction: | 
					
						
						|  | return locate_json_string_body_from_string(result) | 
					
						
						|  | return result | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | nvidia_embed_model = "nvidia/nv-embedqa-e5-v5" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def indexing_embedding_func(texts: list[str]) -> np.ndarray: | 
					
						
						|  | return await nvidia_openai_embed( | 
					
						
						|  | texts, | 
					
						
						|  | model=nvidia_embed_model, | 
					
						
						|  |  | 
					
						
						|  | api_key=NVIDIA_OPENAI_API_KEY, | 
					
						
						|  | base_url="https://integrate.api.nvidia.com/v1", | 
					
						
						|  | input_type="passage", | 
					
						
						|  | trunc="END", | 
					
						
						|  | encode="float", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def query_embedding_func(texts: list[str]) -> np.ndarray: | 
					
						
						|  | return await nvidia_openai_embed( | 
					
						
						|  | texts, | 
					
						
						|  | model=nvidia_embed_model, | 
					
						
						|  |  | 
					
						
						|  | api_key=NVIDIA_OPENAI_API_KEY, | 
					
						
						|  | base_url="https://integrate.api.nvidia.com/v1", | 
					
						
						|  | input_type="query", | 
					
						
						|  | trunc="END", | 
					
						
						|  | encode="float", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def get_embedding_dim(): | 
					
						
						|  | test_text = ["This is a test sentence."] | 
					
						
						|  | embedding = await indexing_embedding_func(test_text) | 
					
						
						|  | embedding_dim = embedding.shape[1] | 
					
						
						|  | return embedding_dim | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def test_funcs(): | 
					
						
						|  | result = await llm_model_func("How are you?") | 
					
						
						|  | print("llm_model_func: ", result) | 
					
						
						|  |  | 
					
						
						|  | result = await indexing_embedding_func(["How are you?"]) | 
					
						
						|  | print("embedding_func: ", result) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def initialize_rag(): | 
					
						
						|  | embedding_dimension = await get_embedding_dim() | 
					
						
						|  | print(f"Detected embedding dimension: {embedding_dimension}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | rag = LightRAG( | 
					
						
						|  | working_dir=WORKING_DIR, | 
					
						
						|  | llm_model_func=llm_model_func, | 
					
						
						|  |  | 
					
						
						|  | embedding_func=EmbeddingFunc( | 
					
						
						|  | embedding_dim=embedding_dimension, | 
					
						
						|  | max_token_size=512, | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | func=indexing_embedding_func, | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | await rag.initialize_storages() | 
					
						
						|  | await initialize_pipeline_status() | 
					
						
						|  |  | 
					
						
						|  | return rag | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def main(): | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | rag = await initialize_rag() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with open("./book.txt", "r", encoding="utf-8") as f: | 
					
						
						|  | await rag.ainsert(f.read()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print("==============Naive===============") | 
					
						
						|  | print( | 
					
						
						|  | await rag.aquery( | 
					
						
						|  | "What are the top themes in this story?", param=QueryParam(mode="naive") | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print("==============local===============") | 
					
						
						|  | print( | 
					
						
						|  | await rag.aquery( | 
					
						
						|  | "What are the top themes in this story?", param=QueryParam(mode="local") | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print("==============global===============") | 
					
						
						|  | print( | 
					
						
						|  | await rag.aquery( | 
					
						
						|  | "What are the top themes in this story?", | 
					
						
						|  | param=QueryParam(mode="global"), | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print("==============hybrid===============") | 
					
						
						|  | print( | 
					
						
						|  | await rag.aquery( | 
					
						
						|  | "What are the top themes in this story?", | 
					
						
						|  | param=QueryParam(mode="hybrid"), | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"An error occurred: {e}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | asyncio.run(main()) | 
					
						
						|  |  |