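# NOTE: non-code header text captured with this file ("Spaces: Sleeping, Sleeping");
# kept as a comment so the module parses.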
import logging
import os
from typing import Optional

import neo4j

from domain.entity_d import (
    EntityD,
    EntityKnowledgeGraphD,
    EntityRelationshipD,
    RelationshipD,
)
class Neo4jError(Exception):
    """Raised when a connection to the Neo4j database cannot be established."""
class Neo4jDomainDAO:
    """Data-access object for the entity knowledge graph stored in Neo4j.

    To be used with a context manager to ensure the connection is closed
    after use::

        with Neo4jDomainDAO() as dao:
            dao.insert(knowledge_graph, pdf_file="report.pdf")

    Connection settings are read from the ``NEO4J_URI``, ``NEO4J_USER`` and
    ``NEO4J_PASSWORD`` environment variables when the context is entered.
    """

    # Database that every query issued by this class runs against.
    _DATABASE = "neo4j"

    # Required settings, validated in this order when the context is entered.
    _REQUIRED_ENV_VARS = ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD")

    def __enter__(self):
        """Open and verify the Neo4j connection.

        Returns:
            This DAO, with ``self.driver`` set to a connected driver.

        Raises:
            ValueError: If one of the required environment variables is
                unset or empty.
            Neo4jError: If the driver cannot be created or the connectivity
                check fails; the underlying error is attached as the cause.
        """
        settings = {}
        for name in self._REQUIRED_ENV_VARS:
            value = os.environ.get(name, "")
            if not value:
                raise ValueError(f"{name} environment variable not set")
            settings[name] = value

        driver = None
        try:
            driver = neo4j.GraphDatabase.driver(
                settings["NEO4J_URI"],
                auth=(settings["NEO4J_USER"], settings["NEO4J_PASSWORD"]))
            driver.verify_connectivity()
        except Exception as e:
            logging.error("Failed to connect to Neo4j: %s", e)
            # __exit__ is not invoked when __enter__ raises, so a driver that
            # was created but failed verification must be released here.
            if driver is not None:
                try:
                    driver.close()
                except Exception:
                    logging.warning(
                        "Failed to close Neo4j driver after connection error",
                        exc_info=True)
            raise Neo4jError("Failed to connect to Neo4j") from e

        self.driver = driver
        return self

    def insert(self, knowledge_graph: "EntityKnowledgeGraphD", pdf_file: str = ""):
        """Run the create commands of every entity relationship in the graph.

        Insertion is best-effort: a command that fails is logged as a
        warning and the remaining commands are still attempted.

        Args:
            knowledge_graph: Graph whose ``entity_relationships`` each expose
                ``neo4j_create_cmds`` and ``neo4j_create_args``; the two are
                paired positionally.
            pdf_file: Value sent as the ``pdf_file`` query parameter with
                every command.
        """
        for entity_relationship in knowledge_graph.entity_relationships:
            create_cmds = entity_relationship.neo4j_create_cmds
            create_cmds_args = entity_relationship.neo4j_create_args
            for create_cmd, args in zip(create_cmds, create_cmds_args):
                # Send a copy so the relationship's own argument dicts are
                # not modified as a side effect of inserting.
                params = {**args, "pdf_file": pdf_file}
                try:
                    self.driver.execute_query(
                        create_cmd,  # type: ignore
                        parameters_=params,  # type: ignore
                        database_=self._DATABASE)  # type: ignore
                except Exception as e:
                    logging.warning(
                        "Failed to insert entity relationship: %s due to %s",
                        entity_relationship, e)

    def query(self, query, query_args):
        """Execute ``query`` with ``query_args`` and return the driver's result.

        Errors raised by the driver propagate to the caller.
        """
        return self.driver.execute_query(query,
                                         parameters_=query_args,
                                         database_=self._DATABASE)  # type: ignore

    def get_knowledge_graph(self) -> "Optional[EntityKnowledgeGraphD]":
        """Load every Entity-Relationship-Entity triple as a knowledge graph.

        Returns:
            The graph built from all matched triples, or ``None`` if the
            query itself fails (the error is logged). Entity and
            relationship ids are not read back and are left empty.
        """
        try:
            records, _, _ = self.driver.execute_query(
                "MATCH (from:Entity) -[r:Relationship]-> (to:Entity) RETURN from, properties(r), to",
                database_=self._DATABASE)  # type: ignore
        except Exception as e:
            logging.exception(e)
            return None

        entity_relationships = []
        for record in records:
            er_dict = record.data()
            from_entity = EntityD(entity_id='',
                                  entity_name=er_dict['from']['name'])
            to_entity = EntityD(entity_id='',
                                entity_name=er_dict['to']['name'])
            relationship_args = er_dict['properties(r)']
            relationship = RelationshipD(
                relationship_id='',
                start_date=relationship_args['start_date'],
                end_date=relationship_args['end_date'],
                source_text=relationship_args['source_text'],
                predicted_movement=RelationshipD.from_string(
                    relationship_args['predicted_movement']))
            entity_relationships.append(
                EntityRelationshipD(from_entity=from_entity,
                                    relationship=relationship,
                                    to_entity=to_entity))
        return EntityKnowledgeGraphD(entity_relationships=entity_relationships)

    def __exit__(self, exception_type, exception_value, traceback):
        """Log any in-flight exception and close the driver.

        Returns ``None``, so an exception raised inside the ``with`` body
        still propagates to the caller.
        """
        try:
            if traceback:
                # exc_info makes the handler render the real stack trace; the
                # bare traceback object only formats as an object repr.
                logging.error("Neo4jDomainDAO error: %s | %s | %s",
                              exception_type,
                              exception_value,
                              traceback,
                              exc_info=(exception_type, exception_value, traceback))
        finally:
            # Close even if logging itself fails.
            self.driver.close()