Add HF Support
Browse files- lightrag/lightrag.py +4 -3
- lightrag/llm.py +82 -2
- lightrag/operate.py +62 -14
    	
        lightrag/lightrag.py
    CHANGED
    
    | @@ -5,7 +5,7 @@ from datetime import datetime | |
| 5 | 
             
            from functools import partial
         | 
| 6 | 
             
            from typing import Type, cast
         | 
| 7 |  | 
| 8 | 
            -
            from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
         | 
| 9 | 
             
            from .operate import (
         | 
| 10 | 
             
                chunking_by_token_size,
         | 
| 11 | 
             
                extract_entities,
         | 
| @@ -77,12 +77,13 @@ class LightRAG: | |
| 77 | 
             
                )
         | 
| 78 |  | 
| 79 | 
             
                # text embedding
         | 
| 80 | 
            -
                embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding | 
| 81 | 
             
                embedding_batch_num: int = 32
         | 
| 82 | 
             
                embedding_func_max_async: int = 16
         | 
| 83 |  | 
| 84 | 
             
                # LLM
         | 
| 85 | 
            -
                llm_model_func: callable = gpt_4o_mini_complete
         | 
|  | |
| 86 | 
             
                llm_model_max_token_size: int = 32768
         | 
| 87 | 
             
                llm_model_max_async: int = 16
         | 
| 88 |  | 
|  | |
| 5 | 
             
            from functools import partial
         | 
| 6 | 
             
            from typing import Type, cast
         | 
| 7 |  | 
| 8 | 
            +
            from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model,hf_embedding
         | 
| 9 | 
             
            from .operate import (
         | 
| 10 | 
             
                chunking_by_token_size,
         | 
| 11 | 
             
                extract_entities,
         | 
|  | |
| 77 | 
             
                )
         | 
| 78 |  | 
| 79 | 
             
                # text embedding
         | 
| 80 | 
            +
                embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)#openai_embedding
         | 
| 81 | 
             
                embedding_batch_num: int = 32
         | 
| 82 | 
             
                embedding_func_max_async: int = 16
         | 
| 83 |  | 
| 84 | 
             
                # LLM
         | 
| 85 | 
            +
                llm_model_func: callable = hf_model#gpt_4o_mini_complete
         | 
| 86 | 
            +
                llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
         | 
| 87 | 
             
                llm_model_max_token_size: int = 32768
         | 
| 88 | 
             
                llm_model_max_async: int = 16
         | 
| 89 |  | 
    	
        lightrag/llm.py
    CHANGED
    
    | @@ -7,10 +7,12 @@ from tenacity import ( | |
| 7 | 
             
                wait_exponential,
         | 
| 8 | 
             
                retry_if_exception_type,
         | 
| 9 | 
             
            )
         | 
| 10 | 
            -
             | 
|  | |
| 11 | 
             
            from .base import BaseKVStorage
         | 
| 12 | 
             
            from .utils import compute_args_hash, wrap_embedding_func_with_attrs
         | 
| 13 | 
            -
             | 
|  | |
| 14 | 
             
            @retry(
         | 
| 15 | 
             
                stop=stop_after_attempt(3),
         | 
| 16 | 
             
                wait=wait_exponential(multiplier=1, min=4, max=10),
         | 
| @@ -42,6 +44,52 @@ async def openai_complete_if_cache( | |
| 42 | 
             
                    )
         | 
| 43 | 
             
                return response.choices[0].message.content
         | 
| 44 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 45 | 
             
            async def gpt_4o_complete(
         | 
| 46 | 
             
                prompt, system_prompt=None, history_messages=[], **kwargs
         | 
| 47 | 
             
            ) -> str:
         | 
| @@ -65,6 +113,20 @@ async def gpt_4o_mini_complete( | |
| 65 | 
             
                    **kwargs,
         | 
| 66 | 
             
                )
         | 
| 67 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 68 | 
             
            @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
         | 
| 69 | 
             
            @retry(
         | 
| 70 | 
             
                stop=stop_after_attempt(3),
         | 
| @@ -78,6 +140,24 @@ async def openai_embedding(texts: list[str]) -> np.ndarray: | |
| 78 | 
             
                )
         | 
| 79 | 
             
                return np.array([dp.embedding for dp in response.data])
         | 
| 80 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 81 | 
             
            if __name__ == "__main__":
         | 
| 82 | 
             
                import asyncio
         | 
| 83 |  | 
|  | |
| 7 | 
             
                wait_exponential,
         | 
| 8 | 
             
                retry_if_exception_type,
         | 
| 9 | 
             
            )
         | 
| 10 | 
            +
            from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
         | 
| 11 | 
            +
            import torch
         | 
| 12 | 
             
            from .base import BaseKVStorage
         | 
| 13 | 
             
            from .utils import compute_args_hash, wrap_embedding_func_with_attrs
         | 
| 14 | 
            +
            import copy
         | 
| 15 | 
            +
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
         | 
| 16 | 
             
            @retry(
         | 
| 17 | 
             
                stop=stop_after_attempt(3),
         | 
| 18 | 
             
                wait=wait_exponential(multiplier=1, min=4, max=10),
         | 
|  | |
| 44 | 
             
                    )
         | 
| 45 | 
             
                return response.choices[0].message.content
         | 
| 46 |  | 
| 47 | 
            +
            async def hf_model_if_cache(
         | 
| 48 | 
            +
                model, prompt, system_prompt=None, history_messages=[], **kwargs
         | 
| 49 | 
            +
            ) -> str:
         | 
| 50 | 
            +
                model_name = model
         | 
| 51 | 
            +
                hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
         | 
| 52 | 
            +
                if hf_tokenizer.pad_token == None:
         | 
| 53 | 
            +
                    # print("use eos token")
         | 
| 54 | 
            +
                    hf_tokenizer.pad_token = hf_tokenizer.eos_token
         | 
| 55 | 
            +
                hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
         | 
| 56 | 
            +
                hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
         | 
| 57 | 
            +
                messages = []
         | 
| 58 | 
            +
                if system_prompt:
         | 
| 59 | 
            +
                    messages.append({"role": "system", "content": system_prompt})
         | 
| 60 | 
            +
                messages.extend(history_messages)
         | 
| 61 | 
            +
                messages.append({"role": "user", "content": prompt})
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                if hashing_kv is not None:
         | 
| 64 | 
            +
                    args_hash = compute_args_hash(model, messages)
         | 
| 65 | 
            +
                    if_cache_return = await hashing_kv.get_by_id(args_hash)
         | 
| 66 | 
            +
                    if if_cache_return is not None:
         | 
| 67 | 
            +
                        return if_cache_return["return"]
         | 
| 68 | 
            +
                input_prompt = ''
         | 
| 69 | 
            +
                try:
         | 
| 70 | 
            +
                    input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)    
         | 
| 71 | 
            +
                except:
         | 
| 72 | 
            +
                    try:
         | 
| 73 | 
            +
                        ori_message = copy.deepcopy(messages)
         | 
| 74 | 
            +
                        if messages[0]['role'] == "system":
         | 
| 75 | 
            +
                            messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
         | 
| 76 | 
            +
                            messages = messages[1:]
         | 
| 77 | 
            +
                            input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)    
         | 
| 78 | 
            +
                    except:      
         | 
| 79 | 
            +
                        len_message = len(ori_message)
         | 
| 80 | 
            +
                        for msgid in range(len_message):
         | 
| 81 | 
            +
                            input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
         | 
| 82 | 
            +
                
         | 
| 83 | 
            +
                input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
         | 
| 84 | 
            +
                output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
         | 
| 85 | 
            +
                response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
         | 
| 86 | 
            +
                if hashing_kv is not None:
         | 
| 87 | 
            +
                    await hashing_kv.upsert(
         | 
| 88 | 
            +
                        {args_hash: {"return": response_text, "model": model}}
         | 
| 89 | 
            +
                    )
         | 
| 90 | 
            +
                return response_text
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
             
            async def gpt_4o_complete(
         | 
| 94 | 
             
                prompt, system_prompt=None, history_messages=[], **kwargs
         | 
| 95 | 
             
            ) -> str:
         | 
|  | |
| 113 | 
             
                    **kwargs,
         | 
| 114 | 
             
                )
         | 
| 115 |  | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            async def hf_model(
         | 
| 119 | 
            +
                prompt, system_prompt=None, history_messages=[], **kwargs
         | 
| 120 | 
            +
            ) -> str:
         | 
| 121 | 
            +
                input_string = kwargs.get('model_name', 'google/gemma-2-2b-it')
         | 
| 122 | 
            +
                return await hf_model_if_cache(
         | 
| 123 | 
            +
                    input_string,
         | 
| 124 | 
            +
                    prompt,
         | 
| 125 | 
            +
                    system_prompt=system_prompt,
         | 
| 126 | 
            +
                    history_messages=history_messages,
         | 
| 127 | 
            +
                    **kwargs,
         | 
| 128 | 
            +
                )
         | 
| 129 | 
            +
             | 
| 130 | 
             
            @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
         | 
| 131 | 
             
            @retry(
         | 
| 132 | 
             
                stop=stop_after_attempt(3),
         | 
|  | |
| 140 | 
             
                )
         | 
| 141 | 
             
                return np.array([dp.embedding for dp in response.data])
         | 
| 142 |  | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            global EMBED_MODEL
         | 
| 146 | 
            +
            global tokenizer
         | 
| 147 | 
            +
            EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
         | 
| 148 | 
            +
            tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
         | 
| 149 | 
            +
            @wrap_embedding_func_with_attrs(
         | 
| 150 | 
            +
                embedding_dim=384,
         | 
| 151 | 
            +
                max_token_size=5000,
         | 
| 152 | 
            +
            )
         | 
| 153 | 
            +
            async def hf_embedding(texts: list[str]) -> np.ndarray:
         | 
| 154 | 
            +
                input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
         | 
| 155 | 
            +
                with torch.no_grad():
         | 
| 156 | 
            +
                    outputs = EMBED_MODEL(input_ids)
         | 
| 157 | 
            +
                    embeddings = outputs.last_hidden_state.mean(dim=1)
         | 
| 158 | 
            +
                return embeddings.detach().numpy()
         | 
| 159 | 
            +
             | 
| 160 | 
            +
             | 
| 161 | 
             
            if __name__ == "__main__":
         | 
| 162 | 
             
                import asyncio
         | 
| 163 |  | 
    	
        lightrag/operate.py
    CHANGED
    
    | @@ -3,7 +3,7 @@ import json | |
| 3 | 
             
            import re
         | 
| 4 | 
             
            from typing import Union
         | 
| 5 | 
             
            from collections import Counter, defaultdict
         | 
| 6 | 
            -
             | 
| 7 | 
             
            from .utils import (
         | 
| 8 | 
             
                logger,
         | 
| 9 | 
             
                clean_str,
         | 
| @@ -398,10 +398,15 @@ async def local_query( | |
| 398 | 
             
                    keywords = keywords_data.get("low_level_keywords", [])
         | 
| 399 | 
             
                    keywords = ', '.join(keywords)
         | 
| 400 | 
             
                except json.JSONDecodeError as e:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 401 | 
             
                    # Handle parsing error
         | 
| 402 | 
            -
                     | 
| 403 | 
            -
             | 
| 404 | 
            -
             | 
| 405 | 
             
                context = await _build_local_query_context(
         | 
| 406 | 
             
                    keywords,
         | 
| 407 | 
             
                    knowledge_graph_inst,
         | 
| @@ -421,6 +426,9 @@ async def local_query( | |
| 421 | 
             
                    query,
         | 
| 422 | 
             
                    system_prompt=sys_prompt,
         | 
| 423 | 
             
                )
         | 
|  | |
|  | |
|  | |
| 424 | 
             
                return response
         | 
| 425 |  | 
| 426 | 
             
            async def _build_local_query_context(
         | 
| @@ -617,9 +625,16 @@ async def global_query( | |
| 617 | 
             
                    keywords = keywords_data.get("high_level_keywords", [])
         | 
| 618 | 
             
                    keywords = ', '.join(keywords)
         | 
| 619 | 
             
                except json.JSONDecodeError as e:
         | 
| 620 | 
            -
                     | 
| 621 | 
            -
             | 
| 622 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 623 |  | 
| 624 | 
             
                context = await _build_global_query_context(
         | 
| 625 | 
             
                    keywords,
         | 
| @@ -643,6 +658,9 @@ async def global_query( | |
| 643 | 
             
                    query,
         | 
| 644 | 
             
                    system_prompt=sys_prompt,
         | 
| 645 | 
             
                )
         | 
|  | |
|  | |
|  | |
| 646 | 
             
                return response
         | 
| 647 |  | 
| 648 | 
             
            async def _build_global_query_context(
         | 
| @@ -822,8 +840,8 @@ async def hybird_query( | |
| 822 |  | 
| 823 | 
             
                kw_prompt_temp = PROMPTS["keywords_extraction"]
         | 
| 824 | 
             
                kw_prompt = kw_prompt_temp.format(query=query)
         | 
|  | |
| 825 | 
             
                result = await use_model_func(kw_prompt)
         | 
| 826 | 
            -
             | 
| 827 | 
             
                try:
         | 
| 828 | 
             
                    keywords_data = json.loads(result)
         | 
| 829 | 
             
                    hl_keywords = keywords_data.get("high_level_keywords", [])
         | 
| @@ -831,10 +849,18 @@ async def hybird_query( | |
| 831 | 
             
                    hl_keywords = ', '.join(hl_keywords)
         | 
| 832 | 
             
                    ll_keywords = ', '.join(ll_keywords)
         | 
| 833 | 
             
                except json.JSONDecodeError as e:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 834 | 
             
                    # Handle parsing error
         | 
| 835 | 
            -
                     | 
| 836 | 
            -
             | 
| 837 | 
            -
             | 
|  | |
| 838 | 
             
                low_level_context = await _build_local_query_context(
         | 
| 839 | 
             
                    ll_keywords,
         | 
| 840 | 
             
                    knowledge_graph_inst,
         | 
| @@ -851,7 +877,7 @@ async def hybird_query( | |
| 851 | 
             
                    text_chunks_db,
         | 
| 852 | 
             
                    query_param,
         | 
| 853 | 
             
                )
         | 
| 854 | 
            -
             | 
| 855 | 
             
                context = combine_contexts(high_level_context, low_level_context)
         | 
| 856 |  | 
| 857 | 
             
                if query_param.only_need_context:
         | 
| @@ -867,10 +893,13 @@ async def hybird_query( | |
| 867 | 
             
                    query,
         | 
| 868 | 
             
                    system_prompt=sys_prompt,
         | 
| 869 | 
             
                )
         | 
|  | |
|  | |
| 870 | 
             
                return response
         | 
| 871 |  | 
| 872 | 
             
            def combine_contexts(high_level_context, low_level_context):
         | 
| 873 | 
             
                # Function to extract entities, relationships, and sources from context strings
         | 
|  | |
| 874 | 
             
                def extract_sections(context):
         | 
| 875 | 
             
                    entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
         | 
| 876 | 
             
                    relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
         | 
| @@ -883,8 +912,21 @@ def combine_contexts(high_level_context, low_level_context): | |
| 883 | 
             
                    return entities, relationships, sources
         | 
| 884 |  | 
| 885 | 
             
                # Extract sections from both contexts
         | 
| 886 | 
            -
             | 
| 887 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 888 |  | 
| 889 | 
             
                # Combine and deduplicate the entities
         | 
| 890 | 
             
                combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
         | 
| @@ -917,6 +959,7 @@ async def naive_query( | |
| 917 | 
             
                global_config: dict,
         | 
| 918 | 
             
            ):
         | 
| 919 | 
             
                use_model_func = global_config["llm_model_func"]
         | 
|  | |
| 920 | 
             
                results = await chunks_vdb.query(query, top_k=query_param.top_k)
         | 
| 921 | 
             
                if not len(results):
         | 
| 922 | 
             
                    return PROMPTS["fail_response"]
         | 
| @@ -939,6 +982,11 @@ async def naive_query( | |
| 939 | 
             
                response = await use_model_func(
         | 
| 940 | 
             
                    query,
         | 
| 941 | 
             
                    system_prompt=sys_prompt,
         | 
|  | |
| 942 | 
             
                )
         | 
|  | |
|  | |
|  | |
|  | |
| 943 | 
             
                return response
         | 
| 944 |  | 
|  | |
| 3 | 
             
            import re
         | 
| 4 | 
             
            from typing import Union
         | 
| 5 | 
             
            from collections import Counter, defaultdict
         | 
| 6 | 
            +
            import warnings
         | 
| 7 | 
             
            from .utils import (
         | 
| 8 | 
             
                logger,
         | 
| 9 | 
             
                clean_str,
         | 
|  | |
| 398 | 
             
                    keywords = keywords_data.get("low_level_keywords", [])
         | 
| 399 | 
             
                    keywords = ', '.join(keywords)
         | 
| 400 | 
             
                except json.JSONDecodeError as e:
         | 
| 401 | 
            +
                    try:
         | 
| 402 | 
            +
                        result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
         | 
| 403 | 
            +
                        keywords_data = json.loads(result)
         | 
| 404 | 
            +
                        keywords = keywords_data.get("low_level_keywords", [])
         | 
| 405 | 
            +
                        keywords = ', '.join(keywords)
         | 
| 406 | 
             
                    # Handle parsing error
         | 
| 407 | 
            +
                    except json.JSONDecodeError as e:
         | 
| 408 | 
            +
                        print(f"JSON parsing error: {e}")
         | 
| 409 | 
            +
                        return PROMPTS["fail_response"]
         | 
| 410 | 
             
                context = await _build_local_query_context(
         | 
| 411 | 
             
                    keywords,
         | 
| 412 | 
             
                    knowledge_graph_inst,
         | 
|  | |
| 426 | 
             
                    query,
         | 
| 427 | 
             
                    system_prompt=sys_prompt,
         | 
| 428 | 
             
                )
         | 
| 429 | 
            +
                if len(response)>len(sys_prompt):
         | 
| 430 | 
            +
                    response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
         | 
| 431 | 
            +
                
         | 
| 432 | 
             
                return response
         | 
| 433 |  | 
| 434 | 
             
            async def _build_local_query_context(
         | 
|  | |
| 625 | 
             
                    keywords = keywords_data.get("high_level_keywords", [])
         | 
| 626 | 
             
                    keywords = ', '.join(keywords)
         | 
| 627 | 
             
                except json.JSONDecodeError as e:
         | 
| 628 | 
            +
                    try:
         | 
| 629 | 
            +
                        result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
         | 
| 630 | 
            +
                        keywords_data = json.loads(result)
         | 
| 631 | 
            +
                        keywords = keywords_data.get("high_level_keywords", [])
         | 
| 632 | 
            +
                        keywords = ', '.join(keywords)
         | 
| 633 | 
            +
                
         | 
| 634 | 
            +
                    except json.JSONDecodeError as e:
         | 
| 635 | 
            +
                        # Handle parsing error
         | 
| 636 | 
            +
                        print(f"JSON parsing error: {e}")
         | 
| 637 | 
            +
                        return PROMPTS["fail_response"]
         | 
| 638 |  | 
| 639 | 
             
                context = await _build_global_query_context(
         | 
| 640 | 
             
                    keywords,
         | 
|  | |
| 658 | 
             
                    query,
         | 
| 659 | 
             
                    system_prompt=sys_prompt,
         | 
| 660 | 
             
                )
         | 
| 661 | 
            +
                if len(response)>len(sys_prompt):
         | 
| 662 | 
            +
                    response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
         | 
| 663 | 
            +
                
         | 
| 664 | 
             
                return response
         | 
| 665 |  | 
| 666 | 
             
            async def _build_global_query_context(
         | 
|  | |
| 840 |  | 
| 841 | 
             
                kw_prompt_temp = PROMPTS["keywords_extraction"]
         | 
| 842 | 
             
                kw_prompt = kw_prompt_temp.format(query=query)
         | 
| 843 | 
            +
                
         | 
| 844 | 
             
                result = await use_model_func(kw_prompt)
         | 
|  | |
| 845 | 
             
                try:
         | 
| 846 | 
             
                    keywords_data = json.loads(result)
         | 
| 847 | 
             
                    hl_keywords = keywords_data.get("high_level_keywords", [])
         | 
|  | |
| 849 | 
             
                    hl_keywords = ', '.join(hl_keywords)
         | 
| 850 | 
             
                    ll_keywords = ', '.join(ll_keywords)
         | 
| 851 | 
             
                except json.JSONDecodeError as e:
         | 
| 852 | 
            +
                    try:
         | 
| 853 | 
            +
                        result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
         | 
| 854 | 
            +
                        keywords_data = json.loads(result)
         | 
| 855 | 
            +
                        hl_keywords = keywords_data.get("high_level_keywords", [])
         | 
| 856 | 
            +
                        ll_keywords = keywords_data.get("low_level_keywords", [])
         | 
| 857 | 
            +
                        hl_keywords = ', '.join(hl_keywords)
         | 
| 858 | 
            +
                        ll_keywords = ', '.join(ll_keywords)
         | 
| 859 | 
             
                    # Handle parsing error
         | 
| 860 | 
            +
                    except json.JSONDecodeError as e:
         | 
| 861 | 
            +
                        print(f"JSON parsing error: {e}")
         | 
| 862 | 
            +
                        return PROMPTS["fail_response"]
         | 
| 863 | 
            +
             | 
| 864 | 
             
                low_level_context = await _build_local_query_context(
         | 
| 865 | 
             
                    ll_keywords,
         | 
| 866 | 
             
                    knowledge_graph_inst,
         | 
|  | |
| 877 | 
             
                    text_chunks_db,
         | 
| 878 | 
             
                    query_param,
         | 
| 879 | 
             
                )
         | 
| 880 | 
            +
             | 
| 881 | 
             
                context = combine_contexts(high_level_context, low_level_context)
         | 
| 882 |  | 
| 883 | 
             
                if query_param.only_need_context:
         | 
|  | |
| 893 | 
             
                    query,
         | 
| 894 | 
             
                    system_prompt=sys_prompt,
         | 
| 895 | 
             
                )
         | 
| 896 | 
            +
                if len(response)>len(sys_prompt):
         | 
| 897 | 
            +
                    response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
         | 
| 898 | 
             
                return response
         | 
| 899 |  | 
| 900 | 
             
            def combine_contexts(high_level_context, low_level_context):
         | 
| 901 | 
             
                # Function to extract entities, relationships, and sources from context strings
         | 
| 902 | 
            +
             | 
| 903 | 
             
                def extract_sections(context):
         | 
| 904 | 
             
                    entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
         | 
| 905 | 
             
                    relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
         | 
|  | |
| 912 | 
             
                    return entities, relationships, sources
         | 
| 913 |  | 
| 914 | 
             
                # Extract sections from both contexts
         | 
| 915 | 
            +
             | 
| 916 | 
            +
                if high_level_context==None:
         | 
| 917 | 
            +
                    warnings.warn("High Level context is None. Return empty High entity/relationship/source") 
         | 
| 918 | 
            +
                    hl_entities, hl_relationships, hl_sources = '','',''
         | 
| 919 | 
            +
                else:
         | 
| 920 | 
            +
                    hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
         | 
| 921 | 
            +
             | 
| 922 | 
            +
             | 
| 923 | 
            +
                if low_level_context==None:
         | 
| 924 | 
            +
                    warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
         | 
| 925 | 
            +
                    ll_entities, ll_relationships, ll_sources = '','',''
         | 
| 926 | 
            +
                else:
         | 
| 927 | 
            +
                    ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
         | 
| 928 | 
            +
             | 
| 929 | 
            +
             | 
| 930 |  | 
| 931 | 
             
                # Combine and deduplicate the entities
         | 
| 932 | 
             
                combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
         | 
|  | |
| 959 | 
             
                global_config: dict,
         | 
| 960 | 
             
            ):
         | 
| 961 | 
             
                use_model_func = global_config["llm_model_func"]
         | 
| 962 | 
            +
                use_model_name = global_config['llm_model_name']
         | 
| 963 | 
             
                results = await chunks_vdb.query(query, top_k=query_param.top_k)
         | 
| 964 | 
             
                if not len(results):
         | 
| 965 | 
             
                    return PROMPTS["fail_response"]
         | 
|  | |
| 982 | 
             
                response = await use_model_func(
         | 
| 983 | 
             
                    query,
         | 
| 984 | 
             
                    system_prompt=sys_prompt,
         | 
| 985 | 
            +
                    model_name = use_model_name
         | 
| 986 | 
             
                )
         | 
| 987 | 
            +
             | 
| 988 | 
            +
                if len(response)>len(sys_prompt):
         | 
| 989 | 
            +
                    response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
         | 
| 990 | 
            +
                
         | 
| 991 | 
             
                return response
         | 
| 992 |  | 
