import networkx as nx
import json
import os
import asyncio
import nest_asyncio
import tqdm
import os.path
import tempfile
import subprocess
from typing import List, Optional, Dict
import logging
import urllib.parse
from .ModelService import create_model_service
from .Node import Node, DirectoryNode, FileNode, ChunkNode, EntityNode
from .CodeParser import CodeParser
from .EntityExtractor import HybridEntityExtractor
from .CodeIndex import CodeIndex
from .utils.logger_utils import setup_logger
from .utils.parsing_utils import read_directory_files_recursively, get_language_from_filename
from .utils.path_utils import prepare_input_path, build_entity_alias_map, resolve_entity_call
from .EntityChunkMapper import EntityChunkMapper
LOGGER_NAME = 'REPO_KNOWLEDGE_GRAPH_LOGGER'
MODEL_SERVICE_TYPES = ['openai', 'sentence-transformers']
# A RepoKnowledgeGraph is a weighted DAG based on a tree-structure with added edges
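# Typical usage (an illustrative sketch; 'path/to/repo' is a placeholder path):
#   kg = RepoKnowledgeGraph.from_path('path/to/repo', extract_entities=True)
#   kg.print_tree(max_depth=2)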
class RepoKnowledgeGraph:
"""
RepoKnowledgeGraph builds a knowledge graph of a code repository.
It parses source files, extracts code entities and relationships, and organizes them
into a directed acyclic graph (DAG) with additional semantic edges.
Use `from_path()` or `load_graph_from_file()` to create instances.
"""
def __init__(self):
"""
Private constructor. Use from_path() or load_graph_from_file() instead.
"""
raise RuntimeError(
"Cannot instantiate RepoKnowledgeGraph directly. "
"Use RepoKnowledgeGraph.from_path() or RepoKnowledgeGraph.load_graph_from_file() instead."
)
def _initialize(self, model_service_kwargs: dict, code_index_kwargs: Optional[dict] = None):
"""Internal initialization method."""
setup_logger(LOGGER_NAME)
self.logger = logging.getLogger(LOGGER_NAME)
self.logger.info('Initializing RepoKnowledgeGraph instance.')
self.code_parser = CodeParser()
# Determine if we should skip loading the embedder based on index_type
index_type = (code_index_kwargs or {}).get('index_type', 'hybrid')
skip_embedder = index_type == 'keyword-only'
if skip_embedder:
self.logger.info('Using keyword-only index, skipping embedder initialization')
self.model_service = create_model_service(skip_embedder=skip_embedder, **model_service_kwargs)
self.entities = {}
self.graph = nx.DiGraph()
self.knowledge_graph = nx.DiGraph()
self.code_index = None
self.entity_extractor = HybridEntityExtractor()
def __iter__(self):
# Yield only the 'data' attribute from each node
return (node_data['data'] for _, node_data in self.graph.nodes(data=True))
def __getitem__(self, node_id):
return self.graph.nodes[node_id]['data']
@classmethod
    def from_path(cls, path: str, skip_dirs: Optional[list] = None, index_nodes: bool = True, describe_nodes=False,
                  extract_entities: bool = False, model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
        """
        Alternative constructor to build a RepoKnowledgeGraph from a path, with options to skip directories
        and control entity extraction and node description.
        Args:
            path (str): Path to the root of the code repository.
            skip_dirs (list, optional): List of directory names to skip.
            index_nodes (bool): Whether to build a code index.
            describe_nodes (bool): Whether to generate descriptions for code chunks.
            extract_entities (bool): Whether to extract entities from code.
            model_service_kwargs (dict, optional): Keyword arguments passed to the model service.
            code_index_kwargs (dict, optional): Keyword arguments passed to the code index.
        Returns:
            RepoKnowledgeGraph: The constructed knowledge graph.
        """
        if skip_dirs is None:
            skip_dirs = []
        if model_service_kwargs is None:
            model_service_kwargs = {}
instance = cls.__new__(cls) # Create instance without calling __init__
instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
instance.logger.info(f"Preparing to build knowledge graph from path: {path}")
prepared_path = prepare_input_path(path)
instance.logger.debug(f"Prepared input path: {prepared_path}")
# Handle running event loop (e.g., in Jupyter)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
instance.logger.debug("Detected running event loop, applying nest_asyncio.")
nest_asyncio.apply()
task = instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes,
describe_nodes=describe_nodes, extract_entities=extract_entities)
loop.run_until_complete(task)
else:
instance.logger.debug("No running event loop, using asyncio.run.")
asyncio.run(instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes,
describe_nodes=describe_nodes,
extract_entities=extract_entities))
instance.logger.info("Parsing files and building initial nodes...")
instance.logger.info("Initial parse and node creation complete. Building relationships between nodes...")
instance._build_relationships()
if index_nodes:
instance.logger.info("Building code index for all nodes in the graph...")
instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **(code_index_kwargs or {}))
instance.logger.info("Knowledge graph construction from path completed successfully.")
return instance
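    # Example (an illustrative sketch; the path is a placeholder, and 'keyword-only' is the
    # index_type recognized in _initialize for skipping the embedder):
    #   kg = RepoKnowledgeGraph.from_path('path/to/repo',
    #                                     code_index_kwargs={'index_type': 'keyword-only'})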
@classmethod
def from_repo(
cls,
repo_url: str,
skip_dirs: Optional[list] = None,
index_nodes: bool = True,
describe_nodes: bool = False,
extract_entities: bool = False,
model_service_kwargs: Optional[dict] = None,
code_index_kwargs: Optional[dict]=None,
github_token: Optional[str] = None,
allow_unauthenticated_clone: bool = True,
):
"""
Alternative constructor to build a RepoKnowledgeGraph from a remote git repository URL.
Args:
repo_url (str): Git repository URL (SSH or HTTPS).
skip_dirs (list): List of directory names to skip.
index_nodes (bool): Whether to build a code index.
describe_nodes (bool): Whether to generate descriptions for code chunks.
extract_entities (bool): Whether to extract entities from code.
github_token (str, optional): Personal access token to access private GitHub repos.
If not provided, the method will look for the `GITHUB_OAUTH_TOKEN` environment variable.
allow_unauthenticated_clone (bool): If True, attempt to clone without a token when none is provided.
If False, raise an error when no token is available.
Returns:
RepoKnowledgeGraph: The constructed knowledge graph.
"""
if skip_dirs is None:
skip_dirs = []
if model_service_kwargs is None:
model_service_kwargs = {}
instance = cls.__new__(cls)
instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
instance.logger.info(f"Starting knowledge graph build from remote repository: {repo_url}")
# Determine token
token = github_token or os.environ.get('GITHUB_OAUTH_TOKEN')
with tempfile.TemporaryDirectory() as tmpdirname:
clone_url = repo_url
try:
if repo_url.startswith('git@'):
# Convert git@github.com:owner/repo.git -> https://github.com/owner/repo.git
clone_url = repo_url.replace(':', '/').split('git@')[-1]
clone_url = f'https://{clone_url}'
if token and clone_url.startswith('https://'):
encoded_token = urllib.parse.quote(token, safe='')
clone_url = clone_url.replace('https://', f'https://{encoded_token}@')
elif not token and not allow_unauthenticated_clone:
raise ValueError(
"GitHub token not provided and unauthenticated clone is disabled. "
"Set allow_unauthenticated_clone=True or provide a token."
)
instance.logger.debug(f"Running git clone: {clone_url} -> {tmpdirname}")
subprocess.run(['git', 'clone', clone_url, tmpdirname], check=True)
except Exception as e:
instance.logger.error(f"Failed to clone repository {repo_url} using URL {clone_url}: {e}")
raise
instance.logger.info(f"Repository successfully cloned to: {tmpdirname}")
return cls.from_path(
tmpdirname,
skip_dirs=skip_dirs,
index_nodes=index_nodes,
describe_nodes=describe_nodes,
extract_entities=extract_entities,
model_service_kwargs=model_service_kwargs,
code_index_kwargs=code_index_kwargs
)
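    # Example (an illustrative sketch; the URL is a placeholder, and the token, if any, is read
    # from the GITHUB_OAUTH_TOKEN environment variable when not passed explicitly):
    #   kg = RepoKnowledgeGraph.from_repo('https://github.com/owner/repo.git')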
    async def _initial_parse_path_async(self, path: str, skip_dirs: list, index_nodes=True, describe_nodes=True,
                                        extract_entities: bool = True):
        """
        Orchestrates the parsing and graph construction process:
        1. Reads files and splits them into chunks.
        2. Extracts entities and relationships.
        3. Builds chunk, file, directory, and root nodes.
        4. Aggregates entity information.
        Args:
            path (str): Root path to parse.
            skip_dirs (list): Directories to skip.
            index_nodes (bool): Whether to build a code index.
            describe_nodes (bool): Whether to generate descriptions.
            extract_entities (bool): Whether to extract entities.
        """
        self.logger.info(f"Beginning async parsing of repository at path: {path}")
# --- Pass 1: Create ChunkNodes ---
level1_node_contents = read_directory_files_recursively(
path, skip_dirs=skip_dirs,
skip_pattern=r"(?:\.log$|\.json$|(?:^|/)(?:\.git|\.idea|__pycache__|\.cache)(?:/|$)|(?:^|/)(?:changelog|ChangeLog)(?:\.[a-z0-9]+)?$|\.cache$)"
)
self.logger.debug(f"Found {len(level1_node_contents)} files to process.")
self.logger.info("Chunk nodes creation step started.")
chunk_info = await self._create_chunk_nodes(
level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=path
)
self.logger.info("Chunk nodes creation step finished.")
self.logger.info("File nodes creation step started.")
file_info = self._create_file_nodes(
chunk_info, level1_node_contents
)
self.logger.info("File nodes creation step finished.")
self.logger.info("Directory nodes creation step started.")
dir_agg = self._create_directory_nodes(
file_info
)
self.logger.info("Directory nodes creation step finished.")
self.logger.info("Aggregating all nodes to root node.")
self._aggregate_to_root(dir_agg)
self.logger.info("Async parse and node aggregation fully complete.")
async def _create_chunk_nodes(self, level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=None):
self.logger.info(f"Starting chunk node creation for {len(level1_node_contents)} files.")
accepted_extensions = {'.py', '.c', '.cpp', '.h', '.hpp', '.java', '.js', '.ts', '.jsx', '.tsx', '.rs', '.html'}
chunk_info = {}
entity_mapper = EntityChunkMapper()
total_chunks = 0
# Use tqdm for progress bar over files
for file_path in tqdm.tqdm(level1_node_contents, desc="Processing files for chunk nodes"):
self.logger.debug(f"Processing file for chunk nodes: {file_path}")
_, ext = os.path.splitext(file_path)
is_code_file = ext.lower() in accepted_extensions
self.logger.debug(f"Parsing file: {file_path}")
# Parse file into chunks
parsed_content = self.code_parser.parse(file_name=file_path, file_content=level1_node_contents[file_path])
self.logger.debug(f"Parsed {len(parsed_content)} chunks from file: {file_path}")
total_chunks += len(parsed_content)
# Entity extraction logging
if extract_entities and is_code_file:
self.logger.debug(f"Extracting entities from code file: {file_path}")
try:
# Construct full path for entity extraction (needed for C/C++ include resolution)
extraction_file_path = os.path.join(root_path, file_path) if root_path else file_path
file_declared_entities, file_called_entities = self.entity_extractor.extract_entities(
code=level1_node_contents[file_path], file_name=extraction_file_path)
self.logger.debug(f"Extracted {len(file_declared_entities)} declared and {len(file_called_entities)} called entities from file: {file_path}")
chunk_declared_map, chunk_called_map = entity_mapper.map_entities_to_chunks(
file_declared_entities, file_called_entities, parsed_content, file_name=file_path)
self.logger.debug(f"Mapped entities to {len(parsed_content)} chunks for file: {file_path}")
except Exception as e:
self.logger.error(f"Error extracting entities from {file_path}: {e}")
file_declared_entities, file_called_entities = [], []
chunk_declared_map = {i: [] for i in range(len(parsed_content))}
chunk_called_map = {i: [] for i in range(len(parsed_content))}
else:
self.logger.debug(f"Skipping entity extraction for non-code file: {file_path}")
file_declared_entities, file_called_entities = [], []
chunk_declared_map = {i: [] for i in range(len(parsed_content))}
chunk_called_map = {i: [] for i in range(len(parsed_content))}
chunk_tasks = []
for i, chunk in enumerate(parsed_content):
chunk_id = f'{file_path}_{i}'
self.logger.debug(f"Scheduling processing for chunk {chunk_id} of file {file_path}")
async def process_chunk(i=i, chunk=chunk, chunk_id=chunk_id):
self.logger.debug(f"Creating chunk node: {chunk_id}")
declared_entities = chunk_declared_map.get(i, [])
called_entities = chunk_called_map.get(i, [])
# FIRST PASS: Register all declared entities with aliases
# Build temporary alias map for checking existing entities
temp_alias_map = build_entity_alias_map(self.entities)
for entity in declared_entities:
name = entity.get("name")
if not name:
continue
# Check if this entity already exists under any of its aliases
entity_aliases = entity.get("aliases", [])
canonical_name = None
# First check if the name itself already exists or is an alias
if name in temp_alias_map:
canonical_name = temp_alias_map[name]
self.logger.debug(f"Entity '{name}' already exists as '{canonical_name}'")
else:
# Check if any of the entity's aliases match existing entities
for alias in entity_aliases:
if alias in temp_alias_map:
canonical_name = temp_alias_map[alias]
self.logger.debug(f"Entity '{name}' matches existing entity '{canonical_name}' via alias '{alias}'")
break
# If we found a match, use the canonical name; otherwise use the entity name
if canonical_name:
entity_key = canonical_name
else:
entity_key = name
self.logger.debug(f"Registering new declared entity '{name}' in chunk {chunk_id}")
self.entities[entity_key] = {
"declaring_chunk_ids": [],
"calling_chunk_ids": [],
"type": [],
"dtype": None,
"aliases": []
}
# Update temp alias map with new entity
temp_alias_map[entity_key] = entity_key
if chunk_id not in self.entities[entity_key]["declaring_chunk_ids"]:
self.entities[entity_key]["declaring_chunk_ids"].append(chunk_id)
entity_type = entity.get("type")
if entity_type and entity_type not in self.entities[entity_key]["type"]:
self.entities[entity_key]["type"].append(entity_type)
dtype = entity.get("dtype")
if dtype:
self.entities[entity_key]["dtype"] = dtype
# Store aliases (add new ones, avoiding duplicates)
for alias in [name] + entity_aliases:
if alias and alias not in self.entities[entity_key]["aliases"]:
self.entities[entity_key]["aliases"].append(alias)
temp_alias_map[alias] = entity_key # Update temp map
self.logger.debug(f"Declared entity '{name}' registered as '{entity_key}' in chunk {chunk_id} with aliases: {self.entities[entity_key]['aliases']}")
# Logging for node creation
if describe_nodes:
self.logger.info(f"Generating description for chunk {chunk_id}")
try:
description = await self.model_service.query_async(
f'Summarize this {get_language_from_filename(file_path)} code chunk in a few sentences: {chunk}')
except Exception as e:
self.logger.error(f"Error generating description for chunk {chunk_id}: {e}")
description = ''
else:
self.logger.debug(f"No description requested for chunk {chunk_id}")
description = ''
chunk_node = ChunkNode(
id=chunk_id,
name=chunk_id,
path=file_path,
content=chunk,
order_in_file=i,
called_entities=called_entities,
declared_entities=declared_entities,
language=get_language_from_filename(file_path),
description=description,
)
self.logger.debug(f"Chunk node created: {chunk_id}")
# NOTE: Embeddings are now deferred to CodeIndex for efficient batch processing
# This avoids the slow one-at-a-time embedding during chunk creation
chunk_node.embedding = None
return (chunk_id, chunk_node, declared_entities, called_entities)
chunk_tasks.append(process_chunk())
chunk_results = await asyncio.gather(*chunk_tasks)
self.logger.debug(f"Finished processing {len(chunk_results)} chunks for file {file_path}.")
chunk_info[file_path] = {
'chunk_results': chunk_results,
'file_declared_entities': file_declared_entities,
'file_called_entities': file_called_entities
}
# Log summary
self.logger.info(f"Created {total_chunks} chunk nodes from {len(level1_node_contents)} files")
# SECOND PASS: Now that all declared entities are registered, resolve called entities
self.logger.info("Starting second pass: resolving called entities using alias map...")
alias_map = build_entity_alias_map(self.entities)
self.logger.info(f"Built alias map with {len(alias_map)} entries for resolution")
resolved_count = 0
for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Resolving called entities"):
chunk_results = file_data['chunk_results']
for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
for called_name in called_entities:
# Skip empty or whitespace-only names
if not called_name or not called_name.strip():
continue
# Try to resolve this called entity to an existing declared entity using aliases
resolved_name = resolve_entity_call(called_name, alias_map)
# Use the resolved name if found, otherwise check if called_name is already an alias
if resolved_name:
entity_key = resolved_name
elif called_name in alias_map:
# The called_name itself is an alias of an existing entity
entity_key = alias_map[called_name]
else:
# No match found, use the original called name
entity_key = called_name
if entity_key not in self.entities:
self.logger.debug(f"Registering new called entity '{entity_key}' (called as '{called_name}') in chunk {chunk_id}")
self.entities[entity_key] = {
"declaring_chunk_ids": [],
"calling_chunk_ids": [],
"type": [],
"dtype": None,
"aliases": []
}
# Add called_name as an alias if it's different from entity_key
if called_name != entity_key:
self.entities[entity_key]["aliases"].append(called_name)
alias_map[called_name] = entity_key # Update alias map
if chunk_id not in self.entities[entity_key]["calling_chunk_ids"]:
self.entities[entity_key]["calling_chunk_ids"].append(chunk_id)
if resolved_name and resolved_name != called_name:
resolved_count += 1
self.logger.debug(f"Called entity '{called_name}' resolved to '{entity_key}' in chunk {chunk_id}")
self.logger.info(f"Resolved {resolved_count} entity calls to existing declarations via aliases")
self.logger.info("All chunk nodes have been created for all files.")
return chunk_info
    def _create_file_nodes(self, chunk_info, level1_node_contents):
        """
        For each file, aggregate chunk information and create FileNode objects.
        Args:
            chunk_info (dict): Mapping file_path -> chunk results and file-level entities.
            level1_node_contents (dict): Mapping file_path -> raw file content.
        Returns:
            dict: Mapping file_path -> file info dict.
        """
        self.logger.info("Starting file node creation.")
file_info = {}
for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Creating file nodes"):
self.logger.info(f"Creating file node for: {file_path}")
parts = os.path.normpath(file_path).split(os.sep)
# Extract file-level entities and chunk results from the stored data
chunk_results = file_data['chunk_results']
file_declared_entities = list(file_data['file_declared_entities']) # Use file-level entities directly
file_called_entities = list(file_data['file_called_entities']) # Use file-level entities directly
chunk_ids = []
for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
self.logger.info(f"Adding chunk node {chunk_id} to graph for file {file_path}")
self.graph.add_node(chunk_id, data=chunk_node, level=2)
chunk_ids.append(chunk_id)
# Note: We're using file-level entities for the FileNode, so we don't need to merge from chunks
# The chunks already have their entities set correctly
file_node = FileNode(
id=file_path,
name=parts[-1],
path=file_path,
node_type='file',
content=level1_node_contents[file_path],
declared_entities=file_declared_entities,
called_entities=file_called_entities,
language=get_language_from_filename(file_path),
)
self.logger.debug(f"Adding file node {file_path} to graph.")
self.graph.add_node(file_path, data=file_node, level=1)
for chunk_id in chunk_ids:
self.graph.add_edge(file_path, chunk_id, relation='contains')
file_info[file_path] = {
'declared_entities': file_declared_entities,
'called_entities': file_called_entities,
'chunk_ids': chunk_ids,
'parts': parts,
}
self.logger.info(f"File node {file_path} added to graph with {len(chunk_ids)} chunks.")
self.logger.info("All file nodes have been created.")
return file_info
    def _create_directory_nodes(self, file_info):
        """
        For each directory, aggregate file information and create DirectoryNode objects.
        Args:
            file_info (dict): Mapping file_path -> file info dict.
        Returns:
            dict: Mapping dir_path -> aggregated entity info.
        """
        self.logger.info("Starting directory node creation.")
def merge_entities(target, source):
# Merge entity lists, avoiding duplicates by (name, type)
existing = set((e.get('name'), e.get('type')) for e in target)
for e in source:
k = (e.get('name'), e.get('type'))
if k not in existing:
target.append(e)
existing.add(k)
def merge_called_entities(target, source):
# Merge called entity lists, avoiding duplicates
existing = set(target)
for e in source:
if e not in existing:
target.append(e)
existing.add(e)
dir_agg = {}
for file_path, info in tqdm.tqdm(file_info.items(), desc="Creating directory nodes"):
self.logger.info(f"Processing directory nodes for file: {file_path}")
parts = os.path.normpath(file_path).split(os.sep)
file_declared_entities = info['declared_entities']
file_called_entities = info['called_entities']
current_parent = 'root'
path_accum = ''
for part in parts[:-1]: # Skip file itself
path_accum = os.path.join(path_accum, part) if path_accum else part
if path_accum not in self.graph:
self.logger.info(f"Adding new directory node: {path_accum}")
dir_node = DirectoryNode(id=path_accum, name=part, path=path_accum)
self.graph.add_node(path_accum, data=dir_node, level=1)
self.graph.add_edge(current_parent, path_accum, relation='contains')
if path_accum not in dir_agg:
dir_agg[path_accum] = {'declared_entities': [], 'called_entities': []}
merge_entities(dir_agg[path_accum]['declared_entities'], file_declared_entities)
merge_called_entities(dir_agg[path_accum]['called_entities'], file_called_entities)
current_parent = path_accum
# Connect file to its parent directory
self.graph.add_edge(current_parent, file_path, relation='contains')
self.logger.info("All directory nodes created.")
return dir_agg
    def _aggregate_to_root(self, dir_agg):
        """
        Aggregate all directory entity information to the root node.
        Args:
            dir_agg (dict): Mapping dir_path -> aggregated entity info.
        """
        self.logger.info("Aggregating directory information to root node.")
def merge_entities(target, source):
# Merge entity lists, avoiding duplicates by (name, type)
existing = set((e.get('name'), e.get('type')) for e in target)
for e in source:
k = (e.get('name'), e.get('type'))
if k not in existing:
target.append(e)
existing.add(k)
def merge_called_entities(target, source):
# Merge called entity lists, avoiding duplicates
existing = set(target)
for e in source:
if e not in existing:
target.append(e)
existing.add(e)
root_node = Node(id='root', name='root', node_type='root')
self.graph.add_node('root', data=root_node, level=0)
root_declared_entities = []
root_called_entities = []
for dir_path, agg in tqdm.tqdm(dir_agg.items(), desc="Aggregating to root"):
node = self.graph.nodes[dir_path]['data']
if not hasattr(node, 'declared_entities'):
node.declared_entities = []
if not hasattr(node, 'called_entities'):
node.called_entities = []
merge_entities(node.declared_entities, agg['declared_entities'])
merge_called_entities(node.called_entities, agg['called_entities'])
merge_entities(root_declared_entities, agg['declared_entities'])
merge_called_entities(root_called_entities, agg['called_entities'])
if not hasattr(root_node, 'declared_entities'):
root_node.declared_entities = []
if not hasattr(root_node, 'called_entities'):
root_node.called_entities = []
merge_entities(root_node.declared_entities, root_declared_entities)
merge_called_entities(root_node.called_entities, root_called_entities)
self.logger.info("Aggregation to root node complete.")
    def _build_relationships(self):
        """
        Build relationships between chunk nodes and entity nodes based on self.entities.
        For each entity in self.entities:
        1. Create an EntityNode with entity_name as the id
        2. Create edges from declaring chunks to the entity node ('declares' relation)
        3. Create edges from the entity node to calling chunks ('called_by' relation)
        4. Resolve called entity names using aliases for better matching
        """
        self.logger.info("Building relationships between chunk nodes based on entities.")
edges_created = 0
entity_nodes_created = 0
# Build alias map for quick lookups
self.logger.info("Building entity alias map for call resolution...")
alias_map = build_entity_alias_map(self.entities)
self.logger.info(f"Built alias map with {len(alias_map)} entries")
# First pass: Create all entity nodes
for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Creating entity nodes"):
# Entity type is stored as a list in 'type' key, get first type or empty string
entity_types = info.get('type', [])
entity_type = entity_types[0] if entity_types else ''
declaring_chunks = info.get('declaring_chunk_ids', [])
calling_chunks = info.get('calling_chunk_ids', [])
aliases = info.get('aliases', [])
# Create EntityNode with entity_name as id
entity_node = EntityNode(
id=entity_name,
name=entity_name,
entity_type=entity_type,
declaring_chunk_ids=declaring_chunks,
calling_chunk_ids=calling_chunks,
aliases=aliases
)
# Add entity node to graph
self.graph.add_node(entity_name, data=entity_node, level=3)
entity_nodes_created += 1
# Log aliases for debugging
if aliases:
self.logger.debug(f"Created EntityNode '{entity_name}' with aliases: {aliases}")
# Create edges from declaring chunks to entity node
for declarer_id in declaring_chunks:
if declarer_id in self.graph:
self.graph.add_edge(declarer_id, entity_name, relation='declares')
edges_created += 1
# Create edges from entity node to calling chunks
for caller_id in calling_chunks:
if caller_id in self.graph and caller_id not in declaring_chunks:
self.graph.add_edge(entity_name, caller_id, relation='called_by')
edges_created += 1
# Second pass: Resolve unmatched entity calls using alias matching
self.logger.info("Resolving entity calls using alias matching...")
resolved_calls = 0
for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Resolving entity calls"):
# Skip entities that already have declarations (they were matched directly)
if info.get('declaring_chunk_ids'):
continue
# Try to resolve this called entity to a declared entity using aliases
resolved_name = resolve_entity_call(entity_name, alias_map)
if resolved_name and resolved_name != entity_name:
# Found a match! Update the calling_chunk_ids of the resolved entity
calling_chunks = info.get('calling_chunk_ids', [])
if resolved_name in self.entities:
for caller_id in calling_chunks:
if caller_id in self.graph:
# Add edge from resolved entity to calling chunk
if not self.graph.has_edge(resolved_name, caller_id):
self.graph.add_edge(resolved_name, caller_id, relation='called_by')
edges_created += 1
resolved_calls += 1
self.logger.debug(f"Resolved call: '{entity_name}' -> '{resolved_name}' in chunk {caller_id}")
self.logger.info(f"_build_relationships: Created {entity_nodes_created} entity nodes, "
f"{edges_created} edges, and resolved {resolved_calls} entity calls using aliases.")
def get_entity_by_alias(self, alias: str) -> Optional[str]:
"""
Get the canonical entity name for a given alias.
Args:
alias: An alias of an entity (e.g., 'MyClass' or 'module.MyClass')
Returns:
Canonical entity name if found, None otherwise
"""
alias_map = build_entity_alias_map(self.entities)
return alias_map.get(alias)
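    # Example (an illustrative sketch; 'module.MyClass' is a placeholder alias):
    #   canonical = kg.get_entity_by_alias('module.MyClass')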
    def resolve_entity_references(self) -> Dict[str, str]:
        """
        Resolve all entity references in the knowledge graph using aliases.
        Returns a mapping of unresolved entity calls to their resolved canonical names.
        Returns:
            Dictionary mapping each called entity name to the canonical entity name it resolves to
        """
alias_map = build_entity_alias_map(self.entities)
resolutions = {}
for entity_name, info in self.entities.items():
# Only look at entities that are called but not declared
if not info.get('declaring_chunk_ids') and info.get('calling_chunk_ids'):
resolved = resolve_entity_call(entity_name, alias_map)
if resolved:
resolutions[entity_name] = resolved
return resolutions
def print_tree(self, max_depth=None, start_node_id='root', level=0, prefix=""):
"""
Print the repository tree structure using the graph with 'contains' edges.
Args:
max_depth (int, optional): Maximum depth to print. None = unlimited.
start_node_id (str): ID of the node to start from. Default is 'root'.
level (int): Internal use only (used for recursion).
prefix (str): Internal use only (used for formatting output).
"""
if max_depth is not None and level > max_depth:
self.logger.debug(f"Max depth {max_depth} reached at node {start_node_id}.")
return
if start_node_id not in self.graph:
self.logger.warning(f"Start node '{start_node_id}' not found in graph.")
return
try:
node_data = self[start_node_id]
except KeyError as e:
self.logger.error(f"KeyError when accessing node {start_node_id}: {e}")
self.logger.error(f"Available node attributes: {list(self.graph.nodes[start_node_id].keys())}")
# Use a fallback approach if 'data' is missing
if 'data' not in self.graph.nodes[start_node_id]:
self.logger.warning(f"Node {start_node_id} has no 'data' attribute, using node itself")
# Create a fallback node if 'data' is missing
if start_node_id == 'root':
# Create a default root node
node_data = Node(id='root', name='root', node_type='root')
# Update the graph node with the fallback data
self.graph.nodes[start_node_id]['data'] = node_data
else:
# Try to infer node type from ID or structure
name = start_node_id.split('/')[-1] if '/' in start_node_id else start_node_id
if '_' in start_node_id and start_node_id.split('_')[-1].isdigit():
# Looks like a chunk ID
node_data = ChunkNode(id=start_node_id, name=name, node_type='chunk')
elif '.' in name:
# Looks like a file
node_data = FileNode(id=start_node_id, name=name, node_type='file', path=start_node_id)
else:
# Fallback to directory or generic node
node_data = DirectoryNode(id=start_node_id, name=name, node_type='directory',
path=start_node_id)
# Update the graph node with the fallback data
self.graph.nodes[start_node_id]['data'] = node_data
return
# Choose icon based on node type
if node_data.node_type == 'file':
node_symbol = "πŸ“„"
elif node_data.node_type == 'chunk':
node_symbol = "πŸ“"
elif node_data.node_type == 'root':
node_symbol = "πŸ“"
elif node_data.node_type == 'directory':
node_symbol = "πŸ“‚"
else:
node_symbol = "πŸ“¦"
if level == 0:
print(f"{node_symbol} {node_data.name} ({node_data.node_type})")
else:
print(f"{prefix}└── {node_symbol} {node_data.name} ({node_data.node_type})")
# Get children via 'contains' edges
children = [
child for child in self.graph.successors(start_node_id)
if self.graph.edges[start_node_id, child].get('relation') == 'contains'
]
child_count = len(children)
for i, child_id in enumerate(children):
is_last = i == child_count - 1
new_prefix = prefix + (" " if is_last else "β”‚ ")
self.print_tree(max_depth, start_node_id=child_id, level=level + 1, prefix=new_prefix)
def to_dict(self):
self.logger.info("Serializing graph to dictionary.")
graph_data = {
'nodes': [],
'edges': []
}
for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes"):
if 'data' not in node_attrs:
self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping in serialization")
continue
node = node_attrs['data']
node_dict = {
'id': node.id or node_id,
'class': node.__class__.__name__,
'data': {
'id': node.id or node_id,
'name': node.name,
'node_type': node.node_type,
'description': getattr(node, 'description', ''),
'declared_entities': list(getattr(node, 'declared_entities', [])),
'called_entities': list(getattr(node, 'called_entities', [])),
}
}
# FileNode-specific
if isinstance(node, FileNode):
node_dict['data']['path'] = node.path
node_dict['data']['content'] = node.content
node_dict['data']['language'] = getattr(node, 'language', '')
# ChunkNode-specific
if isinstance(node, ChunkNode):
node_dict['data']['order_in_file'] = getattr(node, 'order_in_file', 0)
node_dict['data']['embedding'] = getattr(node, 'embedding', None)
# EntityNode-specific
if isinstance(node, EntityNode):
node_dict['data']['entity_type'] = getattr(node, 'entity_type', '')
node_dict['data']['declaring_chunk_ids'] = list(getattr(node, 'declaring_chunk_ids', []))
node_dict['data']['calling_chunk_ids'] = list(getattr(node, 'calling_chunk_ids', []))
node_dict['data']['aliases'] = list(getattr(node, 'aliases', []))
graph_data['nodes'].append(node_dict)
for u, v, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges"):
edge_data = {
'source': u,
'target': v,
'relation': attrs.get('relation', '')
}
if 'entities' in attrs:
edge_data['entities'] = list(attrs['entities'])
graph_data['edges'].append(edge_data)
self.logger.info("Serialization complete.")
return graph_data
@classmethod
def from_dict(cls, data_dict, index_nodes: bool = True, use_embed: bool = True,
model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
instance = cls.__new__(cls) # bypass __init__
instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
instance.logger.info("Deserializing graph from dictionary.")
node_classes = {
'Node': Node,
'FileNode': FileNode,
'ChunkNode': ChunkNode,
'DirectoryNode': DirectoryNode,
'EntityNode': EntityNode,
}
# Create a root node if not present in the data
root_found = any(node_data['id'] == 'root' for node_data in data_dict['nodes'])
if not root_found:
instance.logger.warning("Root node not found in the data, creating one")
root_node = Node(id='root', name='root', node_type='root')
instance.graph.add_node('root', data=root_node, level=0)
# --- Rebuild nodes ---
for node_data in tqdm.tqdm(data_dict['nodes'], desc="Rebuilding nodes"):
cls_name = node_data['class']
node_cls = node_classes.get(cls_name, Node)
kwargs = node_data['data']
# Ensure ID is properly set
if not kwargs.get('id'):
kwargs['id'] = node_data['id']
# Always use lists for declared_entities and called_entities
kwargs['declared_entities'] = list(kwargs.get('declared_entities', []))
kwargs['called_entities'] = list(kwargs.get('called_entities', []))
# FileNode-specific
if node_cls in (FileNode, ChunkNode):
kwargs.setdefault('path', '')
kwargs.setdefault('content', '')
kwargs.setdefault('language', '')
if node_cls == ChunkNode:
kwargs.setdefault('order_in_file', 0)
kwargs.setdefault('embedding', [])
# EntityNode-specific
if node_cls == EntityNode:
kwargs.setdefault('entity_type', '')
kwargs.setdefault('declaring_chunk_ids', [])
kwargs.setdefault('calling_chunk_ids', [])
kwargs.setdefault('aliases', [])
node_instance = node_cls(**kwargs)
instance.graph.add_node(node_data['id'], data=node_instance, level=instance._infer_level(node_instance))
# --- Rebuild edges ---
for edge in tqdm.tqdm(data_dict['edges'], desc="Rebuilding edges"):
source = edge['source']
target = edge['target']
if source in instance.graph and target in instance.graph:
edge_kwargs = {'relation': edge.get('relation', '')}
if 'entities' in edge:
edge_kwargs['entities'] = list(edge['entities'])
instance.graph.add_edge(source, target, **edge_kwargs)
else:
instance.logger.warning(f"Cannot add edge {source} -> {target}, nodes don't exist")
# --- Rebuild instance.entities ---
instance.entities = {}
for node_id, node_attrs in tqdm.tqdm(instance.graph.nodes(data=True), desc="Rebuilding entities"):
node = node_attrs['data']
declared_entities = getattr(node, 'declared_entities', [])
called_entities = getattr(node, 'called_entities', [])
for entity in declared_entities:
if isinstance(entity, dict):
name = entity.get('name')
else:
name = entity
if not name:
continue
if name not in instance.entities:
instance.entities[name] = {
"declaring_chunk_ids": [],
"calling_chunk_ids": [],
"type": [],
"dtype": None
}
# Only add node_id if it is a ChunkNode
if node_id not in instance.entities[name]["declaring_chunk_ids"]:
if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
instance.entities[name]["declaring_chunk_ids"].append(node_id)
if isinstance(entity, dict):
entity_type = entity.get("type")
if entity_type and entity_type not in instance.entities[name]["type"]:
instance.entities[name]["type"].append(entity_type)
dtype = entity.get("dtype")
if dtype:
instance.entities[name]["dtype"] = dtype
for called_name in called_entities:
if not called_name:
continue
if called_name not in instance.entities:
instance.entities[called_name] = {
"declaring_chunk_ids": [],
"calling_chunk_ids": [],
"type": [],
"dtype": None
}
if node_id not in instance.entities[called_name]["calling_chunk_ids"]:
if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
instance.entities[called_name]["calling_chunk_ids"].append(node_id)
if index_nodes:
instance.logger.info("Building code index after deserialization.")
# Merge use_embed with code_index_kwargs, avoiding duplicate keyword arguments
code_idx_kwargs = code_index_kwargs or {}
if 'use_embed' not in code_idx_kwargs:
code_idx_kwargs['use_embed'] = use_embed
instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **code_idx_kwargs)
instance.logger.info("Deserialization complete.")
return instance
    def _infer_level(self, node):
        """Infer the graph level of a node based on its type."""
        if node.node_type == 'root':
            return 0
        elif node.node_type in ('file', 'directory'):
            return 1
        elif node.node_type == 'chunk':
            return 2
        elif isinstance(node, EntityNode):
            # Entity nodes are added at level 3 in _build_relationships
            return 3
        return 1  # Default level
def save_graph_to_file(self, filepath: str):
self.logger.info(f"Saving graph to file: {filepath}")
with open(filepath, 'w') as f:
json.dump(self.to_dict(), f, indent=2)
self.logger.info("Graph saved successfully.")
@classmethod
def load_graph_from_file(cls, filepath: str, index_nodes=True, use_embed: bool = True,
model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
if model_service_kwargs is None:
model_service_kwargs = {}
with open(filepath, 'r') as f:
data = json.load(f)
logging.getLogger(LOGGER_NAME).info(f"Loaded graph data from file: {filepath}")
return cls.from_dict(data, use_embed=use_embed, index_nodes=index_nodes,
model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)
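    # Example (an illustrative sketch; 'repo_graph.json' is a placeholder file name):
    #   kg.save_graph_to_file('repo_graph.json')
    #   kg2 = RepoKnowledgeGraph.load_graph_from_file('repo_graph.json', index_nodes=False)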
def to_hf_dataset(
self,
repo_id: str,
save_embeddings: bool = True,
private: bool = False,
token: Optional[str] = None,
commit_message: Optional[str] = None,
):
"""
Save the knowledge graph to a HuggingFace dataset on the Hub.
The graph is serialized into two splits:
- 'nodes': Contains all node data
- 'edges': Contains all edge relationships
Args:
repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
save_embeddings (bool): If True, saves embedding vectors for chunk nodes.
If False, embeddings are excluded to reduce dataset size.
private (bool): Whether the dataset should be private. Defaults to False.
token (str, optional): HuggingFace API token. If not provided, uses the token
from huggingface_hub login or HF_TOKEN environment variable.
commit_message (str, optional): Custom commit message for the upload.
Returns:
str: URL of the uploaded dataset
"""
try:
from datasets import Dataset, DatasetDict
from huggingface_hub import HfApi
except ImportError:
raise ImportError(
"huggingface_hub and datasets are required for HuggingFace integration. "
"Install them with: pip install huggingface_hub datasets"
)
self.logger.info(f"Preparing to save knowledge graph to HuggingFace dataset: {repo_id}")
self.logger.info(f"save_embeddings={save_embeddings}")
# Serialize nodes
nodes_data = []
for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes for HF dataset"):
if 'data' not in node_attrs:
self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping")
continue
node = node_attrs['data']
node_record = {
'node_id': node.id or node_id,
'node_class': node.__class__.__name__,
'name': node.name,
'node_type': node.node_type,
'description': getattr(node, 'description', '') or '',
'declared_entities': json.dumps(list(getattr(node, 'declared_entities', []))),
'called_entities': json.dumps(list(getattr(node, 'called_entities', []))),
}
# FileNode-specific fields
if isinstance(node, FileNode):
node_record['path'] = node.path
node_record['content'] = node.content
node_record['language'] = getattr(node, 'language', '')
else:
node_record['path'] = ''
node_record['content'] = ''
node_record['language'] = ''
# ChunkNode-specific fields
if isinstance(node, ChunkNode):
node_record['order_in_file'] = getattr(node, 'order_in_file', 0)
if save_embeddings:
embedding = getattr(node, 'embedding', None)
node_record['embedding'] = json.dumps(embedding if embedding is not None else [])
else:
node_record['embedding'] = json.dumps([])
else:
node_record['order_in_file'] = -1
node_record['embedding'] = json.dumps([])
# EntityNode-specific fields
if isinstance(node, EntityNode):
node_record['entity_type'] = getattr(node, 'entity_type', '')
node_record['declaring_chunk_ids'] = json.dumps(list(getattr(node, 'declaring_chunk_ids', [])))
node_record['calling_chunk_ids'] = json.dumps(list(getattr(node, 'calling_chunk_ids', [])))
node_record['aliases'] = json.dumps(list(getattr(node, 'aliases', [])))
else:
node_record['entity_type'] = ''
node_record['declaring_chunk_ids'] = json.dumps([])
node_record['calling_chunk_ids'] = json.dumps([])
node_record['aliases'] = json.dumps([])
nodes_data.append(node_record)
# Serialize edges
edges_data = []
for source, target, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges for HF dataset"):
edge_record = {
'source': source,
'target': target,
'relation': attrs.get('relation', ''),
'entities': json.dumps(list(attrs.get('entities', []))) if 'entities' in attrs else json.dumps([])
}
edges_data.append(edge_record)
# Create datasets
nodes_dataset = Dataset.from_list(nodes_data)
edges_dataset = Dataset.from_list(edges_data)
self.logger.info(f"Created dataset with {len(nodes_data)} nodes and {len(edges_data)} edges")
# Push to Hub - nodes and edges are pushed separately as different configs
# because they have different schemas
if commit_message is None:
base_commit_message = f"Upload knowledge graph ({len(nodes_data)} nodes, {len(edges_data)} edges)"
if not save_embeddings:
base_commit_message += " [embeddings excluded]"
else:
base_commit_message = commit_message
self.logger.info(f"Pushing nodes dataset to HuggingFace Hub: {repo_id}")
nodes_dataset.push_to_hub(
repo_id=repo_id,
config_name="nodes",
private=private,
token=token,
commit_message=f"{base_commit_message} - nodes"
)
self.logger.info(f"Pushing edges dataset to HuggingFace Hub: {repo_id}")
edges_dataset.push_to_hub(
repo_id=repo_id,
config_name="edges",
private=private,
token=token,
commit_message=f"{base_commit_message} - edges"
)
url = f"https://huggingface.co/datasets/{repo_id}"
self.logger.info(f"Dataset successfully uploaded to: {url}")
return url
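    # Example (an illustrative sketch; the repo_id is a placeholder):
    #   url = kg.to_hf_dataset('username/my-repo-graph', save_embeddings=False, private=True)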
@classmethod
def from_hf_dataset(
cls,
repo_id: str,
index_nodes: bool = True,
use_embed: bool = True,
model_service_kwargs: Optional[dict] = None,
code_index_kwargs: Optional[dict] = None,
token: Optional[str] = None,
revision: Optional[str] = None,
):
"""
Load a knowledge graph from a HuggingFace dataset on the Hub.
Args:
repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
index_nodes (bool): Whether to build a code index after loading. Defaults to True.
use_embed (bool): Whether to use existing embeddings from the dataset. Defaults to True.
model_service_kwargs (dict, optional): Arguments for the model service.
code_index_kwargs (dict, optional): Arguments for the code index.
token (str, optional): HuggingFace API token for private datasets.
revision (str, optional): Git revision (branch, tag, or commit) to load from.
Returns:
RepoKnowledgeGraph: The loaded knowledge graph instance.
"""
try:
from datasets import load_dataset
except ImportError:
raise ImportError(
"datasets library is required for HuggingFace integration. "
"Install it with: pip install datasets"
)
if model_service_kwargs is None:
model_service_kwargs = {}
logger = logging.getLogger(LOGGER_NAME)
logger.info(f"Loading knowledge graph from HuggingFace dataset: {repo_id}")
# Load dataset from Hub - nodes and edges are stored as separate configs
logger.info("Loading nodes config...")
nodes_dataset = load_dataset(repo_id, name="nodes", token=token, revision=revision)
logger.info("Loading edges config...")
edges_dataset = load_dataset(repo_id, name="edges", token=token, revision=revision)
# Get the train split (default split when pushing with config_name)
nodes_data = nodes_dataset['train']
edges_data = edges_dataset['train']
logger.info(f"Loaded {len(nodes_data)} nodes and {len(edges_data)} edges from dataset")
# Convert to the dict format expected by from_dict
graph_data = {
'nodes': [],
'edges': []
}
# Reconstruct nodes
for record in tqdm.tqdm(nodes_data, desc="Reconstructing nodes from HF dataset"):
node_dict = {
'id': record['node_id'],
'class': record['node_class'],
'data': {
'id': record['node_id'],
'name': record['name'],
'node_type': record['node_type'],
'description': record['description'],
'declared_entities': json.loads(record['declared_entities']),
'called_entities': json.loads(record['called_entities']),
}
}
# FileNode-specific fields
if record['node_class'] in ('FileNode', 'ChunkNode'):
node_dict['data']['path'] = record['path']
node_dict['data']['content'] = record['content']
node_dict['data']['language'] = record['language']
# ChunkNode-specific fields
if record['node_class'] == 'ChunkNode':
node_dict['data']['order_in_file'] = record['order_in_file']
embedding = json.loads(record['embedding'])
# Only use embedding if use_embed is True and embedding is non-empty
if use_embed and embedding:
node_dict['data']['embedding'] = embedding
else:
node_dict['data']['embedding'] = []
# EntityNode-specific fields
if record['node_class'] == 'EntityNode':
node_dict['data']['entity_type'] = record['entity_type']
node_dict['data']['declaring_chunk_ids'] = json.loads(record['declaring_chunk_ids'])
node_dict['data']['calling_chunk_ids'] = json.loads(record['calling_chunk_ids'])
node_dict['data']['aliases'] = json.loads(record['aliases'])
graph_data['nodes'].append(node_dict)
# Reconstruct edges
for record in tqdm.tqdm(edges_data, desc="Reconstructing edges from HF dataset"):
edge_dict = {
'source': record['source'],
'target': record['target'],
'relation': record['relation'],
}
entities = json.loads(record['entities'])
if entities:
edge_dict['entities'] = entities
graph_data['edges'].append(edge_dict)
logger.info("Dataset reconstruction complete, building graph...")
# Use from_dict to build the graph
return cls.from_dict(
graph_data,
index_nodes=index_nodes,
use_embed=use_embed,
model_service_kwargs=model_service_kwargs,
code_index_kwargs=code_index_kwargs
)
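    # Example (an illustrative sketch; the repo_id is a placeholder):
    #   kg = RepoKnowledgeGraph.from_hf_dataset('username/my-repo-graph', use_embed=True)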
def get_neighbors(self, node_id):
self.logger.debug(f"Getting neighbors for node: {node_id}")
        # Return all nodes directly connected to node_id (successors and predecessors),
        # regardless of edge type.
        neighbors = set(self.graph.successors(node_id))
        neighbors.update(self.graph.predecessors(node_id))
return [self.graph.nodes[n]['data'] for n in neighbors if 'data' in self.graph.nodes[n]]
def get_previous_chunk(self, node_id: str) -> ChunkNode:
self.logger.debug(f"Getting previous chunk for node: {node_id}")
node = self[node_id]
# Check if node is of type ChunkNode
if not isinstance(node, ChunkNode):
raise Exception(f'Cannot get previous chunk on node of type {type(node)}')
if node.order_in_file == 0:
            self.logger.warning(f'Cannot get previous chunk for first chunk {node_id}')
return None
file_path = node.path
previous_chunk_id = f'{file_path}_{node.order_in_file - 1}'
if previous_chunk_id not in self.graph:
raise Exception(f'Previous chunk {previous_chunk_id} not found in graph')
previous_chunk = self[previous_chunk_id]
return previous_chunk
def get_next_chunk(self, node_id: str) -> ChunkNode:
self.logger.debug(f"Getting next chunk for node: {node_id}")
node = self[node_id]
# Check if node is of type ChunkNode
if not isinstance(node, ChunkNode):
            raise Exception(f'Cannot get next chunk on node of type {type(node)}')
file_path = node.path
next_chunk_id = f'{file_path}_{node.order_in_file + 1}'
if next_chunk_id not in self.graph:
self.logger.warning(f'Next chunk {next_chunk_id} not found in graph, it might be the last chunk')
return None
        next_chunk = self[next_chunk_id]
        return next_chunk
def get_all_chunks(self) -> List[ChunkNode]:
self.logger.debug("Getting all chunk nodes.")
chunk_nodes = []
for node in self:
if isinstance(node, ChunkNode):
chunk_nodes.append(node)
return chunk_nodes
    def get_all_files(self) -> List[FileNode]:
        """
        Get all FileNodes in the knowledge graph.
        Returns:
            List[FileNode]: A list of FileNodes in the graph.
        """
        self.logger.debug("Getting all file nodes.")
        file_nodes = []
        for node_data in self:
            # Check node_type as well: ChunkNode inherits from FileNode, so isinstance alone is not enough
            if isinstance(node_data, FileNode) and node_data.node_type == 'file':
                file_nodes.append(node_data)
        return file_nodes
    def get_chunks_of_file(self, file_node_id: str) -> List[ChunkNode]:
        """
        Get all ChunkNodes associated with a specific FileNode.
        Args:
            file_node_id (str): The ID of the file node to get chunks for.
        Returns:
            List[ChunkNode]: A list of ChunkNodes associated with the file.
        """
        self.logger.debug(f"Getting chunks for file node: {file_node_id}")
chunk_nodes = []
for node in self.graph.neighbors(file_node_id):
# Only include ChunkNodes that are connected by a 'contains' edge
edge_data = self.graph.get_edge_data(file_node_id, node)
node_data = self.graph.nodes[node]['data']
if (
isinstance(node_data, ChunkNode)
and node_data.node_type == 'chunk'
and edge_data is not None
and edge_data.get('relation') == 'contains'
):
chunk_nodes.append(node_data)
return chunk_nodes
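    # Example (an illustrative sketch):
    #   for file_node in kg.get_all_files():
    #       chunks = kg.get_chunks_of_file(file_node.id)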
def find_path(self, source_id: str, target_id: str, max_depth: int = 5) -> dict:
"""
Find the shortest path between two nodes in the knowledge graph.
Args:
source_id (str): The ID of the source node.
target_id (str): The ID of the target node.
max_depth (int): Maximum depth to search for a path. Defaults to 5.
Returns:
dict: A dictionary containing path information or error message.
"""
self.logger.debug(f"Finding path from {source_id} to {target_id} with max_depth={max_depth}")
g = self.graph
if source_id not in g:
return {"error": f"Source node '{source_id}' not found."}
if target_id not in g:
return {"error": f"Target node '{target_id}' not found."}
try:
path = nx.shortest_path(g, source=source_id, target=target_id)
if len(path) - 1 > max_depth:
return {
"source_id": source_id,
"target_id": target_id,
"path": [],
"length": len(path) - 1,
"text": f"Path exists but exceeds max_depth of {max_depth} (actual length: {len(path) - 1})"
}
# Build detailed path information
path_details = []
for i, node_id in enumerate(path):
node = g.nodes[node_id]['data']
node_info = {
"node_id": node_id,
"name": getattr(node, 'name', 'Unknown'),
"type": getattr(node, 'node_type', 'Unknown'),
"step": i
}
# Add edge information for all but the last node
if i < len(path) - 1:
next_node_id = path[i + 1]
edge_data = g.get_edge_data(node_id, next_node_id)
node_info["edge_to_next"] = edge_data.get('relation', 'Unknown') if edge_data else 'Unknown'
path_details.append(node_info)
# Format text output
text = f"Path from '{source_id}' to '{target_id}' (length: {len(path) - 1}):\n\n"
for i, node_info in enumerate(path_details):
text += f"{i}. {node_info['name']} ({node_info['type']})\n"
text += f" Node ID: {node_info['node_id']}\n"
if 'edge_to_next' in node_info:
text += f" --[{node_info['edge_to_next']}]--> \n"
return {
"source_id": source_id,
"target_id": target_id,
"path": path_details,
"length": len(path) - 1,
"text": text
}
except nx.NetworkXNoPath:
return {
"source_id": source_id,
"target_id": target_id,
"path": [],
"length": -1,
"text": f"No path found between '{source_id}' and '{target_id}'"
}
except Exception as e:
self.logger.error(f"Error finding path: {str(e)}")
return {"error": f"Error finding path: {str(e)}"}
def get_subgraph(self, node_id: str, depth: int = 2, edge_types: Optional[List[str]] = None) -> dict:
"""
Extract a subgraph around a node up to a specified depth.
Args:
node_id (str): The ID of the central node.
depth (int): The depth/radius of the subgraph to extract. Defaults to 2.
edge_types (Optional[List[str]]): Optional list of edge types to include (e.g., ['calls', 'contains']).
Returns:
dict: A dictionary containing subgraph information or error message.
"""
self.logger.debug(f"Getting subgraph for node {node_id} with depth={depth}, edge_types={edge_types}")
g = self.graph
if node_id not in g:
return {"error": f"Node '{node_id}' not found."}
# Collect nodes within specified depth
nodes_at_depth = {node_id}
all_nodes = {node_id}
for d in range(depth):
next_level = set()
for n in nodes_at_depth:
# Get all neighbors (both incoming and outgoing)
for neighbor in g.successors(n):
if edge_types is None:
next_level.add(neighbor)
else:
edge_data = g.get_edge_data(n, neighbor)
if edge_data and edge_data.get('relation') in edge_types:
next_level.add(neighbor)
for neighbor in g.predecessors(n):
if edge_types is None:
next_level.add(neighbor)
else:
edge_data = g.get_edge_data(neighbor, n)
if edge_data and edge_data.get('relation') in edge_types:
next_level.add(neighbor)
nodes_at_depth = next_level - all_nodes
all_nodes.update(next_level)
# Extract subgraph
subgraph = g.subgraph(all_nodes).copy()
# Build node details
nodes = []
for n in subgraph.nodes():
node = subgraph.nodes[n]['data']
nodes.append({
"node_id": n,
"name": getattr(node, 'name', 'Unknown'),
"type": getattr(node, 'node_type', 'Unknown')
})
# Build edge details
edges = []
for source, target, data in subgraph.edges(data=True):
edges.append({
"source": source,
"target": target,
"relation": data.get('relation', 'Unknown')
})
# Format text output
text = f"Subgraph around '{node_id}' (depth: {depth}):\n"
if edge_types:
text += f"Edge types filter: {', '.join(edge_types)}\n"
text += f"\nNodes: {len(nodes)}\n"
text += f"Edges: {len(edges)}\n\n"
# Group nodes by type
nodes_by_type = {}
for node in nodes:
node_type = node['type']
if node_type not in nodes_by_type:
nodes_by_type[node_type] = []
nodes_by_type[node_type].append(node)
for node_type, type_nodes in nodes_by_type.items():
text += f"{node_type} ({len(type_nodes)}):\n"
for node in type_nodes[:5]:
text += f" - {node['name']} ({node['node_id']})\n"
if len(type_nodes) > 5:
text += f" ... and {len(type_nodes) - 5} more\n"
text += "\n"
# Show edge statistics
edge_by_relation = {}
for edge in edges:
relation = edge['relation']
edge_by_relation[relation] = edge_by_relation.get(relation, 0) + 1
if edge_by_relation:
text += "Edge types:\n"
for relation, count in edge_by_relation.items():
text += f" - {relation}: {count}\n"
return {
"center_node_id": node_id,
"depth": depth,
"edge_types_filter": edge_types,
"node_count": len(nodes),
"edge_count": len(edges),
"nodes": nodes,
"edges": edges,
"nodes_by_type": nodes_by_type,
"edge_by_relation": edge_by_relation,
"text": text
}
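    # Example (an illustrative sketch; the node ID is a placeholder, and 'contains'/'declares'
    # are relation types created elsewhere in this class):
    #   sub = kg.get_subgraph('src/app.py', depth=2, edge_types=['contains', 'declares'])
    #   print(sub['text'])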