yangdx
commited on
Commit
·
a944ef3
1
Parent(s):
48caec7
Improve graph query speed by batch operation
Browse files- lightrag/kg/postgres_impl.py +112 -64
lightrag/kg/postgres_impl.py
CHANGED
|
@@ -1881,77 +1881,127 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1881 |
|
| 1882 |
result.is_truncated = False
|
| 1883 |
|
|
|
|
| 1884 |
while queue:
|
| 1885 |
-
#
|
| 1886 |
-
|
| 1887 |
-
|
| 1888 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1889 |
if current_depth > max_depth:
|
| 1890 |
continue
|
| 1891 |
-
|
| 1892 |
-
#
|
| 1893 |
-
|
| 1894 |
-
|
| 1895 |
-
|
| 1896 |
-
|
| 1897 |
-
|
| 1898 |
-
|
| 1899 |
-
|
| 1900 |
-
|
| 1901 |
-
|
| 1902 |
-
|
| 1903 |
-
|
| 1904 |
-
|
| 1905 |
-
|
| 1906 |
-
|
| 1907 |
-
|
| 1908 |
-
|
| 1909 |
-
|
| 1910 |
-
|
| 1911 |
-
|
| 1912 |
-
|
| 1913 |
-
|
| 1914 |
-
|
| 1915 |
-
|
| 1916 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1917 |
for record in neighbors:
|
| 1918 |
-
if not record.get("
|
| 1919 |
continue
|
| 1920 |
-
|
| 1921 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1922 |
rel = record["r"]
|
| 1923 |
edge_id = str(record["edge_id"])
|
| 1924 |
-
|
| 1925 |
-
|
| 1926 |
-
|
| 1927 |
-
|
| 1928 |
-
|
| 1929 |
-
continue
|
| 1930 |
-
|
| 1931 |
-
target_entity_id = b_node["properties"]["entity_id"]
|
| 1932 |
-
target_internal_id = str(b_node["id"])
|
| 1933 |
-
|
| 1934 |
-
# Create KnowledgeGraphNode for target
|
| 1935 |
-
target_node = KnowledgeGraphNode(
|
| 1936 |
-
id=target_internal_id,
|
| 1937 |
-
labels=[target_entity_id],
|
| 1938 |
properties=b_node["properties"],
|
| 1939 |
)
|
| 1940 |
-
|
| 1941 |
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
|
| 1942 |
-
sorted_pair = tuple(sorted([current_entity_id,
|
| 1943 |
-
|
| 1944 |
# Create edge object
|
| 1945 |
edge = KnowledgeGraphEdge(
|
| 1946 |
id=edge_id,
|
| 1947 |
type=rel["label"],
|
| 1948 |
-
source=
|
| 1949 |
-
target=
|
| 1950 |
properties=rel["properties"],
|
| 1951 |
)
|
| 1952 |
-
|
| 1953 |
-
if
|
| 1954 |
-
# Add backward edge if
|
| 1955 |
if (
|
| 1956 |
edge_id not in visited_edges
|
| 1957 |
and sorted_pair not in visited_edge_pairs
|
|
@@ -1959,17 +2009,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1959 |
result.edges.append(edge)
|
| 1960 |
visited_edges.add(edge_id)
|
| 1961 |
visited_edge_pairs.add(sorted_pair)
|
| 1962 |
-
|
| 1963 |
else:
|
| 1964 |
if len(visited_node_ids) < max_nodes and current_depth < max_depth:
|
| 1965 |
-
#
|
| 1966 |
-
result.nodes.append(
|
| 1967 |
-
visited_nodes.add(
|
| 1968 |
-
visited_node_ids.add(
|
| 1969 |
-
|
| 1970 |
# Add node to queue with incremented depth
|
| 1971 |
-
queue.append((
|
| 1972 |
-
|
| 1973 |
# Add forward edge
|
| 1974 |
if (
|
| 1975 |
edge_id not in visited_edges
|
|
@@ -1978,7 +2027,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1978 |
result.edges.append(edge)
|
| 1979 |
visited_edges.add(edge_id)
|
| 1980 |
visited_edge_pairs.add(sorted_pair)
|
| 1981 |
-
# logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}")
|
| 1982 |
else:
|
| 1983 |
if current_depth < max_depth:
|
| 1984 |
result.is_truncated = True
|
|
|
|
| 1881 |
|
| 1882 |
result.is_truncated = False
|
| 1883 |
|
| 1884 |
+
# BFS search main loop
|
| 1885 |
while queue:
|
| 1886 |
+
# Get all nodes at the current depth
|
| 1887 |
+
current_level_nodes = []
|
| 1888 |
+
current_depth = None
|
| 1889 |
+
|
| 1890 |
+
# Determine current depth
|
| 1891 |
+
if queue:
|
| 1892 |
+
current_depth = queue[0][1]
|
| 1893 |
+
|
| 1894 |
+
# Extract all nodes at current depth from the queue
|
| 1895 |
+
while queue and queue[0][1] == current_depth:
|
| 1896 |
+
node, depth = queue.popleft()
|
| 1897 |
+
if depth > max_depth:
|
| 1898 |
+
continue
|
| 1899 |
+
current_level_nodes.append(node)
|
| 1900 |
+
|
| 1901 |
+
if not current_level_nodes:
|
| 1902 |
+
continue
|
| 1903 |
+
|
| 1904 |
+
# Check depth limit
|
| 1905 |
if current_depth > max_depth:
|
| 1906 |
continue
|
| 1907 |
+
|
| 1908 |
+
# Prepare node IDs list
|
| 1909 |
+
node_ids = [node.labels[0] for node in current_level_nodes]
|
| 1910 |
+
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
|
| 1911 |
+
|
| 1912 |
+
# Construct batch query for outgoing edges
|
| 1913 |
+
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1914 |
+
UNWIND [{formatted_ids}] AS node_id
|
| 1915 |
+
MATCH (n:base {{entity_id: node_id}})
|
| 1916 |
+
OPTIONAL MATCH (n)-[r]->(neighbor:base)
|
| 1917 |
+
RETURN node_id AS current_id,
|
| 1918 |
+
id(n) AS current_internal_id,
|
| 1919 |
+
id(neighbor) AS neighbor_internal_id,
|
| 1920 |
+
neighbor.entity_id AS neighbor_id,
|
| 1921 |
+
id(r) AS edge_id,
|
| 1922 |
+
r,
|
| 1923 |
+
neighbor,
|
| 1924 |
+
true AS is_outgoing
|
| 1925 |
+
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
| 1926 |
+
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
| 1927 |
+
|
| 1928 |
+
# Construct batch query for incoming edges
|
| 1929 |
+
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1930 |
+
UNWIND [{formatted_ids}] AS node_id
|
| 1931 |
+
MATCH (n:base {{entity_id: node_id}})
|
| 1932 |
+
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
|
| 1933 |
+
RETURN node_id AS current_id,
|
| 1934 |
+
id(n) AS current_internal_id,
|
| 1935 |
+
id(neighbor) AS neighbor_internal_id,
|
| 1936 |
+
neighbor.entity_id AS neighbor_id,
|
| 1937 |
+
id(r) AS edge_id,
|
| 1938 |
+
r,
|
| 1939 |
+
neighbor,
|
| 1940 |
+
false AS is_outgoing
|
| 1941 |
+
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
| 1942 |
+
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
| 1943 |
+
|
| 1944 |
+
# Execute queries
|
| 1945 |
+
outgoing_results = await self._query(outgoing_query)
|
| 1946 |
+
incoming_results = await self._query(incoming_query)
|
| 1947 |
+
|
| 1948 |
+
# Combine results
|
| 1949 |
+
neighbors = outgoing_results + incoming_results
|
| 1950 |
+
|
| 1951 |
+
# Create mapping from node ID to node object
|
| 1952 |
+
node_map = {node.labels[0]: node for node in current_level_nodes}
|
| 1953 |
+
|
| 1954 |
+
# Process all results in a single loop
|
| 1955 |
for record in neighbors:
|
| 1956 |
+
if not record.get("neighbor") or not record.get("r"):
|
| 1957 |
continue
|
| 1958 |
+
|
| 1959 |
+
# Get current node information
|
| 1960 |
+
current_entity_id = record["current_id"]
|
| 1961 |
+
current_node = node_map[current_entity_id]
|
| 1962 |
+
|
| 1963 |
+
# Get neighbor node information
|
| 1964 |
+
neighbor_entity_id = record["neighbor_id"]
|
| 1965 |
+
neighbor_internal_id = str(record["neighbor_internal_id"])
|
| 1966 |
+
is_outgoing = record["is_outgoing"]
|
| 1967 |
+
|
| 1968 |
+
# Determine edge direction
|
| 1969 |
+
if is_outgoing:
|
| 1970 |
+
source_id = current_node.id
|
| 1971 |
+
target_id = neighbor_internal_id
|
| 1972 |
+
else:
|
| 1973 |
+
source_id = neighbor_internal_id
|
| 1974 |
+
target_id = current_node.id
|
| 1975 |
+
|
| 1976 |
+
if not neighbor_entity_id:
|
| 1977 |
+
continue
|
| 1978 |
+
|
| 1979 |
+
# Get edge and node information
|
| 1980 |
+
b_node = record["neighbor"]
|
| 1981 |
rel = record["r"]
|
| 1982 |
edge_id = str(record["edge_id"])
|
| 1983 |
+
|
| 1984 |
+
# Create neighbor node object
|
| 1985 |
+
neighbor_node = KnowledgeGraphNode(
|
| 1986 |
+
id=neighbor_internal_id,
|
| 1987 |
+
labels=[neighbor_entity_id],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1988 |
properties=b_node["properties"],
|
| 1989 |
)
|
| 1990 |
+
|
| 1991 |
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
|
| 1992 |
+
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
|
| 1993 |
+
|
| 1994 |
# Create edge object
|
| 1995 |
edge = KnowledgeGraphEdge(
|
| 1996 |
id=edge_id,
|
| 1997 |
type=rel["label"],
|
| 1998 |
+
source=source_id,
|
| 1999 |
+
target=target_id,
|
| 2000 |
properties=rel["properties"],
|
| 2001 |
)
|
| 2002 |
+
|
| 2003 |
+
if neighbor_internal_id in visited_node_ids:
|
| 2004 |
+
# Add backward edge if neighbor node is already visited
|
| 2005 |
if (
|
| 2006 |
edge_id not in visited_edges
|
| 2007 |
and sorted_pair not in visited_edge_pairs
|
|
|
|
| 2009 |
result.edges.append(edge)
|
| 2010 |
visited_edges.add(edge_id)
|
| 2011 |
visited_edge_pairs.add(sorted_pair)
|
|
|
|
| 2012 |
else:
|
| 2013 |
if len(visited_node_ids) < max_nodes and current_depth < max_depth:
|
| 2014 |
+
# Add new node to result and queue
|
| 2015 |
+
result.nodes.append(neighbor_node)
|
| 2016 |
+
visited_nodes.add(neighbor_entity_id)
|
| 2017 |
+
visited_node_ids.add(neighbor_internal_id)
|
| 2018 |
+
|
| 2019 |
# Add node to queue with incremented depth
|
| 2020 |
+
queue.append((neighbor_node, current_depth + 1))
|
| 2021 |
+
|
| 2022 |
# Add forward edge
|
| 2023 |
if (
|
| 2024 |
edge_id not in visited_edges
|
|
|
|
| 2027 |
result.edges.append(edge)
|
| 2028 |
visited_edges.add(edge_id)
|
| 2029 |
visited_edge_pairs.add(sorted_pair)
|
|
|
|
| 2030 |
else:
|
| 2031 |
if current_depth < max_depth:
|
| 2032 |
result.is_truncated = True
|