yangdx commited on
Commit
a944ef3
·
1 Parent(s): 48caec7

Improve graph query speed by batch operation

Browse files
Files changed (1) hide show
  1. lightrag/kg/postgres_impl.py +112 -64
lightrag/kg/postgres_impl.py CHANGED
@@ -1881,77 +1881,127 @@ class PGGraphStorage(BaseGraphStorage):
1881
 
1882
  result.is_truncated = False
1883
 
 
1884
  while queue:
1885
- # Dequeue the next node to process from the front of the queue
1886
- current_node, current_depth = queue.popleft()
1887
-
1888
- # Check one more depth for backward edges
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1889
  if current_depth > max_depth:
1890
  continue
1891
-
1892
- # Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency
1893
- current_entity_id = current_node.labels[0]
1894
- outgoing_query = """SELECT * FROM cypher('%s', $$
1895
- MATCH (a:base {entity_id: "%s"})-[r]->(b)
1896
- WITH r, b, id(r) as edge_id, id(b) as target_id
1897
- RETURN r, b, edge_id, target_id
1898
- $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
1899
- self.graph_name,
1900
- current_entity_id,
1901
- )
1902
- incoming_query = """SELECT * FROM cypher('%s', $$
1903
- MATCH (a:base {entity_id: "%s"})<-[r]-(b)
1904
- WITH r, b, id(r) as edge_id, id(b) as target_id
1905
- RETURN r, b, edge_id, target_id
1906
- $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
1907
- self.graph_name,
1908
- current_entity_id,
1909
- )
1910
-
1911
- outgoing_neighbors = await self._query(outgoing_query)
1912
- incoming_neighbors = await self._query(incoming_query)
1913
- neighbors = outgoing_neighbors + incoming_neighbors
1914
-
1915
- # logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})")
1916
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1917
  for record in neighbors:
1918
- if not record.get("b") or not record.get("r"):
1919
  continue
1920
-
1921
- b_node = record["b"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1922
  rel = record["r"]
1923
  edge_id = str(record["edge_id"])
1924
-
1925
- if (
1926
- "properties" not in b_node
1927
- or "entity_id" not in b_node["properties"]
1928
- ):
1929
- continue
1930
-
1931
- target_entity_id = b_node["properties"]["entity_id"]
1932
- target_internal_id = str(b_node["id"])
1933
-
1934
- # Create KnowledgeGraphNode for target
1935
- target_node = KnowledgeGraphNode(
1936
- id=target_internal_id,
1937
- labels=[target_entity_id],
1938
  properties=b_node["properties"],
1939
  )
1940
-
1941
  # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
1942
- sorted_pair = tuple(sorted([current_entity_id, target_entity_id]))
1943
-
1944
  # Create edge object
1945
  edge = KnowledgeGraphEdge(
1946
  id=edge_id,
1947
  type=rel["label"],
1948
- source=current_node.id,
1949
- target=target_internal_id,
1950
  properties=rel["properties"],
1951
  )
1952
-
1953
- if target_internal_id in visited_node_ids:
1954
- # Add backward edge if target node is visited
1955
  if (
1956
  edge_id not in visited_edges
1957
  and sorted_pair not in visited_edge_pairs
@@ -1959,17 +2009,16 @@ class PGGraphStorage(BaseGraphStorage):
1959
  result.edges.append(edge)
1960
  visited_edges.add(edge_id)
1961
  visited_edge_pairs.add(sorted_pair)
1962
-
1963
  else:
1964
  if len(visited_node_ids) < max_nodes and current_depth < max_depth:
1965
- # If target node not yet visited, add to result and queue
1966
- result.nodes.append(target_node)
1967
- visited_nodes.add(target_entity_id)
1968
- visited_node_ids.add(target_internal_id)
1969
-
1970
  # Add node to queue with incremented depth
1971
- queue.append((target_node, current_depth + 1))
1972
-
1973
  # Add forward edge
1974
  if (
1975
  edge_id not in visited_edges
@@ -1978,7 +2027,6 @@ class PGGraphStorage(BaseGraphStorage):
1978
  result.edges.append(edge)
1979
  visited_edges.add(edge_id)
1980
  visited_edge_pairs.add(sorted_pair)
1981
- # logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}")
1982
  else:
1983
  if current_depth < max_depth:
1984
  result.is_truncated = True
 
1881
 
1882
  result.is_truncated = False
1883
 
1884
+ # BFS search main loop
1885
  while queue:
1886
+ # Get all nodes at the current depth
1887
+ current_level_nodes = []
1888
+ current_depth = None
1889
+
1890
+ # Determine current depth
1891
+ if queue:
1892
+ current_depth = queue[0][1]
1893
+
1894
+ # Extract all nodes at current depth from the queue
1895
+ while queue and queue[0][1] == current_depth:
1896
+ node, depth = queue.popleft()
1897
+ if depth > max_depth:
1898
+ continue
1899
+ current_level_nodes.append(node)
1900
+
1901
+ if not current_level_nodes:
1902
+ continue
1903
+
1904
+ # Check depth limit
1905
  if current_depth > max_depth:
1906
  continue
1907
+
1908
+ # Prepare node IDs list
1909
+ node_ids = [node.labels[0] for node in current_level_nodes]
1910
+ formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
1911
+
1912
+ # Construct batch query for outgoing edges
1913
+ outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1914
+ UNWIND [{formatted_ids}] AS node_id
1915
+ MATCH (n:base {{entity_id: node_id}})
1916
+ OPTIONAL MATCH (n)-[r]->(neighbor:base)
1917
+ RETURN node_id AS current_id,
1918
+ id(n) AS current_internal_id,
1919
+ id(neighbor) AS neighbor_internal_id,
1920
+ neighbor.entity_id AS neighbor_id,
1921
+ id(r) AS edge_id,
1922
+ r,
1923
+ neighbor,
1924
+ true AS is_outgoing
1925
+ $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
1926
+ neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
1927
+
1928
+ # Construct batch query for incoming edges
1929
+ incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1930
+ UNWIND [{formatted_ids}] AS node_id
1931
+ MATCH (n:base {{entity_id: node_id}})
1932
+ OPTIONAL MATCH (n)<-[r]-(neighbor:base)
1933
+ RETURN node_id AS current_id,
1934
+ id(n) AS current_internal_id,
1935
+ id(neighbor) AS neighbor_internal_id,
1936
+ neighbor.entity_id AS neighbor_id,
1937
+ id(r) AS edge_id,
1938
+ r,
1939
+ neighbor,
1940
+ false AS is_outgoing
1941
+ $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
1942
+ neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
1943
+
1944
+ # Execute queries
1945
+ outgoing_results = await self._query(outgoing_query)
1946
+ incoming_results = await self._query(incoming_query)
1947
+
1948
+ # Combine results
1949
+ neighbors = outgoing_results + incoming_results
1950
+
1951
+ # Create mapping from node ID to node object
1952
+ node_map = {node.labels[0]: node for node in current_level_nodes}
1953
+
1954
+ # Process all results in a single loop
1955
  for record in neighbors:
1956
+ if not record.get("neighbor") or not record.get("r"):
1957
  continue
1958
+
1959
+ # Get current node information
1960
+ current_entity_id = record["current_id"]
1961
+ current_node = node_map[current_entity_id]
1962
+
1963
+ # Get neighbor node information
1964
+ neighbor_entity_id = record["neighbor_id"]
1965
+ neighbor_internal_id = str(record["neighbor_internal_id"])
1966
+ is_outgoing = record["is_outgoing"]
1967
+
1968
+ # Determine edge direction
1969
+ if is_outgoing:
1970
+ source_id = current_node.id
1971
+ target_id = neighbor_internal_id
1972
+ else:
1973
+ source_id = neighbor_internal_id
1974
+ target_id = current_node.id
1975
+
1976
+ if not neighbor_entity_id:
1977
+ continue
1978
+
1979
+ # Get edge and node information
1980
+ b_node = record["neighbor"]
1981
  rel = record["r"]
1982
  edge_id = str(record["edge_id"])
1983
+
1984
+ # Create neighbor node object
1985
+ neighbor_node = KnowledgeGraphNode(
1986
+ id=neighbor_internal_id,
1987
+ labels=[neighbor_entity_id],
 
 
 
 
 
 
 
 
 
1988
  properties=b_node["properties"],
1989
  )
1990
+
1991
  # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
1992
+ sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
1993
+
1994
  # Create edge object
1995
  edge = KnowledgeGraphEdge(
1996
  id=edge_id,
1997
  type=rel["label"],
1998
+ source=source_id,
1999
+ target=target_id,
2000
  properties=rel["properties"],
2001
  )
2002
+
2003
+ if neighbor_internal_id in visited_node_ids:
2004
+ # Add backward edge if neighbor node is already visited
2005
  if (
2006
  edge_id not in visited_edges
2007
  and sorted_pair not in visited_edge_pairs
 
2009
  result.edges.append(edge)
2010
  visited_edges.add(edge_id)
2011
  visited_edge_pairs.add(sorted_pair)
 
2012
  else:
2013
  if len(visited_node_ids) < max_nodes and current_depth < max_depth:
2014
+ # Add new node to result and queue
2015
+ result.nodes.append(neighbor_node)
2016
+ visited_nodes.add(neighbor_entity_id)
2017
+ visited_node_ids.add(neighbor_internal_id)
2018
+
2019
  # Add node to queue with incremented depth
2020
+ queue.append((neighbor_node, current_depth + 1))
2021
+
2022
  # Add forward edge
2023
  if (
2024
  edge_id not in visited_edges
 
2027
  result.edges.append(edge)
2028
  visited_edges.add(edge_id)
2029
  visited_edge_pairs.add(sorted_pair)
 
2030
  else:
2031
  if current_depth < max_depth:
2032
  result.is_truncated = True