""" |
|
Reference: |
|
- [graphrag](https://github.com/microsoft/graphrag) |
|
- [LightRag](https://github.com/HKUDS/LightRAG) |
|
""" |
|
|
|
import html
import json
import logging
import re
import time
from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable

import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph

from api import settings
from rag.nlp import search, rag_tokenizer
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
|
|
|
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] |
|
|
|
|
|
def perform_variable_replacements(
    input: str, history: list[dict] | None = None, variables: dict | None = None
) -> str:
    """Perform variable replacements on the input string and in a chat log."""
    if history is None:
        history = []
    if variables is None:
        variables = {}

    def replace_all(text: str) -> str:
        result = text
        for k, v in variables.items():
            result = result.replace(f"{{{k}}}", v)
        return result

    result = replace_all(input)
    # System messages in the history are rewritten in place.
    for entry in history:
        if entry.get("role") == "system":
            entry["content"] = replace_all(entry.get("content") or "")

    return result
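# Usage sketch (illustrative values only): placeholders use single-brace syntax,
# and system messages in `history` are mutated in place.
#
#   >>> history = [{"role": "system", "content": "You are {name}."}]
#   >>> perform_variable_replacements("Hi {name}", history, {"name": "Ava"})
#   'Hi Ava'
#   >>> history[0]["content"]
#   'You are Ava.'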
|
|
|
|
|
def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # Strip double quotes, ASCII control characters, and C1 control characters.
    return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)
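# Example (illustrative):
#
#   >>> clean_str('  &quot;hello&quot;\x00 ')
#   'hello'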
|
|
|
|
|
def dict_has_keys_with_types(
    data: dict, expected_fields: list[tuple[str, type]]
) -> bool:
    """Return True if the given dictionary has the given keys with the given types."""
    for field, field_type in expected_fields:
        if field not in data:
            return False

        value = data[field]
        if not isinstance(value, field_type):
            return False
    return True
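# Example (illustrative):
#
#   >>> dict_has_keys_with_types({"id": 1, "name": "x"}, [("id", int), ("name", str)])
#   True
#   >>> dict_has_keys_with_types({"id": "1"}, [("id", int)])
#   False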
|
|
|
|
|
def get_llm_cache(llmnm, txt, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    cached = REDIS_CONN.get(k)
    if not cached:
        return
    return cached
|
|
|
|
|
def set_llm_cache(llmnm, txt, v, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    # Cache completions for 24 hours.
    REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600)
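# Usage sketch (illustrative): the cache key is an xxh64 digest over the model
# name, prompt, history, and generation config, so changing any one of them
# yields a cache miss. `chat_model.chat` below is a hypothetical caller.
#
#   cached = get_llm_cache("gpt-4o", prompt, history, genconf)
#   if not cached:
#       answer = chat_model.chat(prompt, history, genconf)  # hypothetical
#       set_llm_cache("gpt-4o", prompt, answer, history, genconf)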
|
|
|
|
|
def get_embed_cache(llmnm, txt):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    cached = REDIS_CONN.get(k)
    if not cached:
        return
    return np.array(json.loads(cached))
|
|
|
|
|
def set_embed_cache(llmnm, txt, arr):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    # numpy arrays are JSON-serialized as plain lists; cache for 24 hours.
    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
    REDIS_CONN.set(k, arr.encode("utf-8"), 24 * 3600)
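# Usage sketch (illustrative): embeddings round-trip through JSON, so
# get_embed_cache returns an np.ndarray on a hit, mirroring how set_entity
# and set_relation below use it.
#
#   vec = get_embed_cache(embd_mdl.llm_name, "apple")
#   if vec is None:
#       batch, _ = embd_mdl.encode(["apple"])
#       set_embed_cache(embd_mdl.llm_name, "apple", batch[0])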
|
|
|
|
|
def get_tags_from_cache(kb_ids):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    cached = REDIS_CONN.get(k)
    if not cached:
        return
    return cached
|
|
|
|
|
def set_tags_to_cache(kb_ids, tags):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    # Tags change often, so cache for only 10 minutes.
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
|
|
|
|
|
def graph_merge(g1, g2):
    """Merge graph g1 into a copy of g2 and recompute node ranks."""
    g = g2.copy()
    for n, attr in g1.nodes(data=True):
        if n not in g2.nodes():
            g.add_node(n, **attr)

    for source, target, attr in g1.edges(data=True):
        if g.has_edge(source, target):
            # Shared edges accumulate weight instead of being overwritten.
            g[source][target].update({"weight": attr.get("weight", 0) + 1})
            continue
        g.add_edge(source, target, **attr)

    for node_degree in g.degree:
        g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    return g
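# Minimal sketch of the merge semantics (illustrative values):
#
#   a = nx.Graph(); a.add_edge("A", "B", weight=2.0)
#   b = nx.Graph(); b.add_edge("A", "B", weight=5.0)
#   merged = graph_merge(a, b)
#   merged["A"]["B"]["weight"]   # 3.0: g1's weight (2.0) incremented by 1
#   merged.nodes["A"]["rank"]    # 1: degree of "A" in the merged graph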
|
|
|
|
|
def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()
|
|
|
|
|
def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    # Expected record shape: ('"entity"', name, type, description).
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None

    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=chunk_key,
    )
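# Example (illustrative): records come from the LLM's delimited extraction output.
#
#   rec = ['"entity"', 'Marie Curie', 'person', 'Physicist and chemist.']
#   handle_single_entity_extraction(rec, chunk_key="doc1#3")
#   # -> {'entity_name': 'MARIE CURIE', 'entity_type': 'PERSON',
#   #     'description': 'Physicist and chemist.', 'source_id': 'doc1#3'}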
|
|
|
|
|
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
    # Expected record shape: ('"relationship"', source, target, description, keywords[, weight]).
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None

    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])
    edge_keywords = clean_str(record_attributes[4])
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    # Normalize the pair order so (A, B) and (B, A) map to the same edge.
    pair = sorted([source, target])
    return dict(
        src_id=pair[0],
        tgt_id=pair[1],
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=chunk_key,
        metadata={"created_at": time.time()},
    )
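# Example (illustrative): the trailing numeric field, when present, becomes the
# edge weight; otherwise the weight defaults to 1.0.
#
#   rec = ['"relationship"', 'Pierre Curie', 'Marie Curie',
#          'Spouses and collaborators.', 'marriage,radioactivity', '9.0']
#   handle_single_relationship_extraction(rec, chunk_key="doc1#3")["weight"]  # 9.0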
|
|
|
|
|
def pack_user_ass_to_openai_messages(*args: str):
    # Alternate user/assistant roles over the given contents, in order.
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]
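# Example (illustrative):
#
#   >>> pack_user_ass_to_openai_messages("hi", "hello!", "how are you?")
#   [{'role': 'user', 'content': 'hi'},
#    {'role': 'assistant', 'content': 'hello!'},
#    {'role': 'user', 'content': 'how are you?'}]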
|
|
|
|
|
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers."""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]
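# Example (illustrative):
#
#   >>> split_string_by_multi_markers("a##b<|>c", ["##", "<|>"])
#   ['a', 'b', 'c']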
|
|
|
|
|
def is_float_regex(value):
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
|
|
|
|
def chunk_id(chunk):
    # Deterministic chunk ID derived from content plus knowledge-base ID.
    return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
|
|
|
|
|
def get_entity(tenant_id, kb_id, ent_name):
    conds = {
        "fields": ["content_with_weight"],
        "entity_kwd": ent_name,
        "size": 10000,
        "knowledge_graph_kwd": ["entity"]
    }
    res = []
    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
    for id in es_res.ids:
        try:
            # A single name returns the first match; a list of names returns all matches.
            if isinstance(ent_name, str):
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue

    return res
|
|
|
|
|
def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
    chunk = {
        "important_kwd": [ent_name],
        "title_tks": rag_tokenizer.tokenize(ent_name),
        "entity_kwd": ent_name,
        "knowledge_graph_kwd": "entity",
        "entity_type_kwd": meta["entity_type"],
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "source_id": list(set(meta["source_id"])),
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
                                      search.index_name(tenant_id), [kb_id])
    if res.ids:
        settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
    else:
        # New entity: embed its name (through the cache) before inserting.
        ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
        if ebd is None:
            try:
                ebd, _ = embd_mdl.encode([ent_name])
                ebd = ebd[0]
                set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
            except Exception as e:
                logging.exception(f"Failed to embed entity: {e}")
        if ebd is not None:
            chunk["q_%d_vec" % len(ebd)] = ebd
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
|
|
|
|
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
    ents = from_ent_name
    if isinstance(ents, str):
        ents = [from_ent_name]
    if isinstance(to_ent_name, str):
        to_ent_name = [to_ent_name]
    ents.extend(to_ent_name)
    ents = list(set(ents))
    conds = {
        "fields": ["content_with_weight"],
        "size": size,
        "from_entity_kwd": ents,
        "to_entity_kwd": ents,
        "knowledge_graph_kwd": ["relation"]
    }
    res = []
    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    for id in es_res.ids:
        try:
            if size == 1:
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue
    return res
|
|
|
|
|
def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
    chunk = {
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": list(set(meta["source_id"])),
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    # Look up the edge by both endpoints (the original queried to_ent_name twice).
    res = settings.retrievaler.search({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
                                      search.index_name(tenant_id), [kb_id])

    if res.ids:
        settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
                                     chunk,
                                     search.index_name(tenant_id), kb_id)
    else:
        # New relation: embed "source->target" (through the cache) before inserting.
        txt = f"{from_ent_name}->{to_ent_name}"
        ebd = get_embed_cache(embd_mdl.llm_name, txt)
        if ebd is None:
            try:
                ebd, _ = embd_mdl.encode([txt + f": {meta['description']}"])
                ebd = ebd[0]
                set_embed_cache(embd_mdl.llm_name, txt, ebd)
            except Exception as e:
                logging.exception(f"Failed to embed entity relation: {e}")
        if ebd is not None:
            chunk["q_%d_vec" % len(ebd)] = ebd
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
|
|
|
|
def get_graph(tenant_id, kb_id):
    conds = {
        "fields": ["content_with_weight", "source_id"],
        "removed_kwd": "N",
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
    for id in res.ids:
        try:
            return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
                   res.field[id]["source_id"]
        except Exception:
            continue
    # Fall back to reconstructing the graph from stored entities and relations.
    return rebuild_graph(tenant_id, kb_id)
|
|
|
|
|
def set_graph(tenant_id, kb_id, graph, docids):
    chunk = {
        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
                                          indent=2),
        "knowledge_graph_kwd": "graph",
        "kb_id": kb_id,
        "source_id": list(docids),
        "available_int": 0,
        "removed_kwd": "N"
    }
    res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
    if res.ids:
        settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
                                     search.index_name(tenant_id), kb_id)
    else:
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
|
|
|
|
def is_continuous_subsequence(subseq, seq):
    """Return True if subseq's first and last elements appear adjacently in seq."""

    def find_all_indexes(tup, value):
        indexes = []
        start = 0
        while True:
            try:
                index = tup.index(value, start)
                indexes.append(index)
                start = index + 1
            except ValueError:
                break
        return indexes

    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1:
            if seq[idx + 1] == subseq[-1]:
                return True
    return False
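# Example (illustrative): only the first and last elements of `subseq` are
# checked for adjacency in `seq`.
#
#   >>> is_continuous_subsequence(("B", "C"), ("A", "B", "C", "D"))
#   True
#   >>> is_continuous_subsequence(("B", "D"), ("A", "B", "C", "D"))
#   False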
|
|
|
|
|
def merge_tuples(list1, list2):
    """Extend each tuple in list1 with matching tuples from list2 (matched on the last element)."""
    result = []
    for tup in list1:
        last_element = tup[-1]
        if last_element in tup[:-1]:
            # The path already revisits this node; keep it as-is.
            result.append(tup)
        else:
            matching_tuples = [t for t in list2 if t[0] == last_element]
            already_match_flag = 0
            for match in matching_tuples:
                reversed_match = (match[1], match[0])
                # Skip extensions that would immediately backtrack along the path.
                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(reversed_match, tup):
                    continue
                already_match_flag = 1
                merged_tuple = tup + match[1:]
                result.append(merged_tuple)
            if not already_match_flag:
                result.append(tup)
    return result
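# Example (illustrative): ("A", "B") is extended by every edge leaving "B"
# that does not immediately backtrack.
#
#   >>> merge_tuples([("A", "B")], [("B", "C"), ("B", "A")])
#   [('A', 'B', 'C')]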
|
|
|
|
|
def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
    def n_neighbor(id):
        nonlocal graph, n_hop
        count = 0
        source_edge = list(graph.edges(id))
        if not source_edge:
            return []
        count = count + 1
        # Expand paths hop by hop until n_hop is reached.
        while count < n_hop:
            count = count + 1
            sc_edge = deepcopy(source_edge)
            source_edge = []
            for pair in sc_edge:
                append_edge = list(graph.edges(pair[-1]))
                for tuples in merge_tuples([pair], append_edge):
                    source_edge.append(tuples)
        nbrs = []
        wts = nx.get_edge_attributes(graph, "weight")
        for path in source_edge:
            n = {"path": path, "weights": []}
            for i in range(len(path) - 1):
                f, t = path[i], path[i + 1]
                n["weights"].append(wts.get((f, t), 0))
            nbrs.append(n)
        return nbrs

    pr = nx.pagerank(graph)
    for n, p in pr.items():
        graph.nodes[n]["pagerank"] = p
        try:
            settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
                                         {"rank_flt": p,
                                          "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
                                         search.index_name(tenant_id), kb_id)
        except Exception as e:
            logging.exception(e)

    # Keep only the top-ranked entities (by PageRank) for each entity type.
    ty2ents = defaultdict(list)
    for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
        ty = graph.nodes[p].get("entity_type")
        if not ty or len(ty2ents[ty]) > 12:
            continue
        ty2ents[ty].append(p)

    chunk = {
        "content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
        "kb_id": kb_id,
        "knowledge_graph_kwd": "ty2ents",
        "available_int": 0
    }
    res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
                                      search.index_name(tenant_id), [kb_id])
    if res.ids:
        settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
                                     chunk,
                                     search.index_name(tenant_id), kb_id)
    else:
        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
|
|
|
|
|
def get_entity_type2sampels(idxnms, kb_ids: list):
    es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
                                          "size": 10000,
                                          "fields": ["content_with_weight"]},
                                         idxnms, kb_ids)

    res = defaultdict(list)
    for id in es_res.ids:
        smp = es_res.field[id].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            # Skip unparseable payloads instead of iterating over a raw string.
            logging.exception(e)
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res
|
|
|
|
|
def flat_uniq_list(arr, key):
    res = []
    for a in arr:
        a = a[key]
        if isinstance(a, list):
            res.extend(a)
        else:
            res.append(a)
    return list(set(res))
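# Example (illustrative): order of the result is not guaranteed, since it
# round-trips through a set.
#
#   >>> sorted(flat_uniq_list([{"tags": ["a", "b"]}, {"tags": "b"}], "tags"))
#   ['a', 'b']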
|
|
|
|
|
def rebuild_graph(tenant_id, kb_id):
    graph = nx.Graph()
    src_ids = []
    flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
    bs = 256
    # Page through all entity and relation chunks of this knowledge base.
    for i in range(0, 10000000, bs):
        es_res = settings.docStoreConn.search(flds, [],
                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
                                              [],
                                              OrderByExpr(),
                                              i, bs, search.index_name(tenant_id), [kb_id]
                                              )
        tot = settings.docStoreConn.getTotal(es_res)
        if tot == 0:
            return None, None

        es_res = settings.docStoreConn.getFields(es_res, flds)
        for id, d in es_res.items():
            src_ids.extend(d.get("source_id", []))
            if d["knowledge_graph_kwd"] == "entity":
                graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
            else:
                graph.add_edge(
                    d["from_entity_kwd"],
                    d["to_entity_kwd"],
                    weight=int(d["weight_int"])
                )

        # A partial page means every matching chunk has been consumed.
        if len(es_res.keys()) < bs:
            return graph, list(set(src_ids))

    return graph, list(set(src_ids))
|
|