Add HF Support

Changed files:
- lightrag/lightrag.py  +4 -3
- lightrag/llm.py       +82 -2
- lightrag/operate.py   +62 -14
lightrag/lightrag.py CHANGED

@@ -5,7 +5,7 @@ from datetime import datetime
 from functools import partial
 from typing import Type, cast
 
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
+from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model,hf_embedding
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -77,12 +77,13 @@ class LightRAG:
     )
 
     # text embedding
-    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)#openai_embedding
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = gpt_4o_mini_complete
+    llm_model_func: callable = hf_model#gpt_4o_mini_complete
+    llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
 
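With these defaults, a LightRAG instance runs against local Hugging Face models unless the caller overrides the fields above. The following is a minimal usage sketch, not part of this commit: the working directory and query text are placeholders, and it assumes the insert()/query() interface and QueryParam "naive" mode described in the project README.

# Illustrative sketch only; paths and strings are placeholders.
from lightrag import LightRAG, QueryParam
from lightrag.llm import hf_model, hf_embedding

rag = LightRAG(
    working_dir="./hf_rag_cache",                      # hypothetical working directory
    llm_model_func=hf_model,                           # new default added by this commit
    llm_model_name="meta-llama/Llama-3.2-1B-Instruct", # new field added by this commit
    embedding_func=hf_embedding,                       # 384-dim sentence-transformers embedding
)

rag.insert("LightRAG indexes documents into a knowledge graph.")
print(rag.query("What does LightRAG do?", param=QueryParam(mode="naive")))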
lightrag/llm.py CHANGED

@@ -7,10 +7,12 @@ from tenacity import (
     wait_exponential,
     retry_if_exception_type,
 )
-
+from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
+import torch
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
-
+import copy
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -42,6 +44,52 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+async def hf_model_if_cache(
+    model, prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    model_name = model
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
+    if hf_tokenizer.pad_token == None:
+        # print("use eos token")
+        hf_tokenizer.pad_token = hf_tokenizer.eos_token
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+    input_prompt = ''
+    try:
+        input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    except:
+        try:
+            ori_message = copy.deepcopy(messages)
+            if messages[0]['role'] == "system":
+                messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
+                messages = messages[1:]
+            input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        except:
+            len_message = len(ori_message)
+            for msgid in range(len_message):
+                input_prompt = input_prompt + '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
+
+    input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
+    output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
+    response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response_text, "model": model}}
+        )
+    return response_text
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -65,6 +113,20 @@ async def gpt_4o_mini_complete(
         **kwargs,
     )
 
+
+
+async def hf_model(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    input_string = kwargs.get('model_name', 'google/gemma-2-2b-it')
+    return await hf_model_if_cache(
+        input_string,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
@@ -78,6 +140,24 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
     )
     return np.array([dp.embedding for dp in response.data])
 
+
+
+global EMBED_MODEL
+global tokenizer
+EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+@wrap_embedding_func_with_attrs(
+    embedding_dim=384,
+    max_token_size=5000,
+)
+async def hf_embedding(texts: list[str]) -> np.ndarray:
+    input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
+    with torch.no_grad():
+        outputs = EMBED_MODEL(input_ids)
+        embeddings = outputs.last_hidden_state.mean(dim=1)
+    return embeddings.detach().numpy()
+
+
 if __name__ == "__main__":
     import asyncio
 
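Both new helpers are coroutines, so they need an event loop. Below is an illustrative sketch of calling them outside of LightRAG; the prompt strings are placeholders, and it assumes the model weights can be downloaded and that a CUDA device is available, since hf_model_if_cache moves its inputs to "cuda".

# Illustrative sketch only: exercises the new HF helpers directly.
import asyncio

from lightrag.llm import hf_model, hf_embedding

async def main():
    answer = await hf_model(
        "Name one benefit of running a local model.",
        system_prompt="You are a concise assistant.",
        model_name="meta-llama/Llama-3.2-1B-Instruct",  # falls back to google/gemma-2-2b-it if omitted
    )
    print(answer)

    vectors = await hf_embedding(["knowledge graph", "retrieval augmented generation"])
    print(vectors.shape)  # (2, 384) with all-MiniLM-L6-v2

asyncio.run(main())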
lightrag/operate.py CHANGED

@@ -3,7 +3,7 @@ import json
 import re
 from typing import Union
 from collections import Counter, defaultdict
-
+import warnings
 from .utils import (
     logger,
     clean_str,
@@ -398,10 +398,15 @@ async def local_query(
         keywords = keywords_data.get("low_level_keywords", [])
         keywords = ', '.join(keywords)
     except json.JSONDecodeError as e:
+        try:
+            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            keywords = keywords_data.get("low_level_keywords", [])
+            keywords = ', '.join(keywords)
         # Handle parsing error
-
-
-
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
     context = await _build_local_query_context(
         keywords,
         knowledge_graph_inst,
@@ -421,6 +426,9 @@ async def local_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response)>len(sys_prompt):
+        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+
     return response
 
 async def _build_local_query_context(
@@ -617,9 +625,16 @@ async def global_query(
         keywords = keywords_data.get("high_level_keywords", [])
         keywords = ', '.join(keywords)
     except json.JSONDecodeError as e:
-
-
-
+        try:
+            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            keywords = keywords_data.get("high_level_keywords", [])
+            keywords = ', '.join(keywords)
+
+        except json.JSONDecodeError as e:
+            # Handle parsing error
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
 
     context = await _build_global_query_context(
         keywords,
@@ -643,6 +658,9 @@ async def global_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response)>len(sys_prompt):
+        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+
     return response
 
 async def _build_global_query_context(
@@ -822,8 +840,8 @@ async def hybird_query(
 
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
+
     result = await use_model_func(kw_prompt)
-
     try:
         keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
@@ -831,10 +849,18 @@ async def hybird_query(
         hl_keywords = ', '.join(hl_keywords)
         ll_keywords = ', '.join(ll_keywords)
     except json.JSONDecodeError as e:
+        try:
+            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            hl_keywords = keywords_data.get("high_level_keywords", [])
+            ll_keywords = keywords_data.get("low_level_keywords", [])
+            hl_keywords = ', '.join(hl_keywords)
+            ll_keywords = ', '.join(ll_keywords)
         # Handle parsing error
-
-
-
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
+
     low_level_context = await _build_local_query_context(
         ll_keywords,
         knowledge_graph_inst,
@@ -851,7 +877,7 @@ async def hybird_query(
         text_chunks_db,
         query_param,
     )
-
+
     context = combine_contexts(high_level_context, low_level_context)
 
     if query_param.only_need_context:
@@ -867,10 +893,13 @@ async def hybird_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response)>len(sys_prompt):
+        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
     return response
 
 def combine_contexts(high_level_context, low_level_context):
     # Function to extract entities, relationships, and sources from context strings
+
     def extract_sections(context):
         entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
         relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
@@ -883,8 +912,21 @@ def combine_contexts(high_level_context, low_level_context):
         return entities, relationships, sources
 
     # Extract sections from both contexts
-
-
+
+    if high_level_context==None:
+        warnings.warn("High Level context is None. Return empty High entity/relationship/source")
+        hl_entities, hl_relationships, hl_sources = '','',''
+    else:
+        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
+
+
+    if low_level_context==None:
+        warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
+        ll_entities, ll_relationships, ll_sources = '','',''
+    else:
+        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
+
+
 
     # Combine and deduplicate the entities
     combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
@@ -917,6 +959,7 @@ async def naive_query(
     global_config: dict,
 ):
     use_model_func = global_config["llm_model_func"]
+    use_model_name = global_config['llm_model_name']
    results = await chunks_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return PROMPTS["fail_response"]
@@ -939,6 +982,11 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        model_name = use_model_name
     )
+
+    if len(response)>len(sys_prompt):
+        response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+
     return response
 
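The recovery path added to local_query, global_query, and hybird_query is the same in all three places: strip the echoed prompt, chat-role markers, and markdown fencing from the model output, then retry json.loads before giving up. A standalone sketch of that idea follows; the function name and sample strings are illustrative and not part of the codebase.

# Illustrative helper mirroring the fallback parsing added in this commit.
import json

def parse_keywords_with_fallback(result: str, kw_prompt: str) -> dict:
    """Parse keyword-extraction JSON, retrying once after cleanup."""
    try:
        return json.loads(result)
    except json.JSONDecodeError:
        # Mirror the diff's cleanup: drop the echoed prompt, "user"/"model"
        # role markers, and ``` / json fencing that small HF models emit.
        cleaned = (
            result.replace(kw_prompt[:-1], '')
            .replace('user', '')
            .replace('model', '')
            .strip()
            .strip('```')
            .strip('json')
        )
        return json.loads(cleaned)  # raises JSONDecodeError again if still malformed

prompt = "Extract keywords from: What causes inflation?"
raw = 'json\n{"high_level_keywords": ["economics"], "low_level_keywords": ["inflation"]}\n'
print(parse_keywords_with_fallback(raw, prompt))

Note that the blanket replace('user','') and replace('model','') calls, taken verbatim from the diff, will also clip those substrings out of keyword text that happens to contain them; this mirrors the committed behavior rather than improving on it.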