# Page header captured with the source (not code): "Spaces: Sleeping"
import base64
import io
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from neo4j import GraphDatabase
class Neo4jGraphVisualizer:
    """Read every node and relationship from Neo4j and render them as a PNG.

    The instance owns the driver created in ``__init__``. Call ``close()``
    when done, or use the instance in a ``with`` block, which calls
    ``close()`` on exit.
    """

    def __init__(self, uri, username, password):
        """
        Initialize the Neo4j graph database connection

        Args:
            uri (str): Neo4j database URI
            username (str): Neo4j username
            password (str): Neo4j password
        """
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def __enter__(self):
        """Return the visualizer itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver on leaving the ``with`` block; never suppress errors."""
        self.close()
        return False

    def fetch_graph_data(self):
        """
        Fetch graph data from Neo4j database

        Returns:
            dict: ``{'nodes': [...], 'relationships': [...]}`` where each node
            is ``{'id', 'label', 'properties'}`` (``label`` is the first Neo4j
            label, or ``'Unknown'`` when the node has none) and each
            relationship is ``{'source', 'target', 'type', 'properties'}``.
        """
        with self.driver.session() as session:
            # Fetch all nodes with elementId instead of deprecated ID()
            nodes_result = session.run("""
                MATCH (n)
                RETURN elementId(n) as id,
                       labels(n) as labels,
                       properties(n) as properties
            """)
            # Read this result completely before issuing the next query, so
            # nothing depends on an unread result surviving a second run().
            nodes = [
                {
                    'id': record['id'],
                    'label': record['labels'][0] if record['labels'] else 'Unknown',
                    # "or {}" keeps a null property map from crashing dict()
                    'properties': dict(record['properties'] or {}),
                }
                for record in nodes_result
            ]

            # Fetch all relationships using elementId
            relationships_result = session.run("""
                MATCH (a)-[r]->(b)
                RETURN
                    elementId(a) as source_id,
                    elementId(b) as target_id,
                    type(r) as relationship_type,
                    properties(r) as relationship_properties
            """)
            relationships = [
                {
                    'source': record['source_id'],
                    'target': record['target_id'],
                    'type': record['relationship_type'],
                    'properties': dict(record.get('relationship_properties') or {}),
                }
                for record in relationships_result
            ]

        return {'nodes': nodes, 'relationships': relationships}

    def _build_networkx_graph(self, graph_data):
        """Turn the dict returned by ``fetch_graph_data`` into an ``nx.DiGraph``."""
        G = nx.DiGraph()

        for node in graph_data['nodes']:
            # Display the "name" property when present, else the element id
            label = node.get('properties', {}).get('name', str(node['id']))
            G.add_node(node['id'],
                       label=label,
                       properties=node['properties'])

        for rel in graph_data['relationships']:
            G.add_edge(rel['source'], rel['target'],
                       type=rel['type'],
                       properties=rel['properties'])

        return G

    def visualize_graph(self):
        """
        Visualize the graph using NetworkX and Matplotlib

        Returns:
            str: On success, a ``data:image/png;base64,...`` URI holding the
            rendered PNG. On any failure the exception is printed and a
            string starting with ``"Error visualizing graph: "`` is returned
            instead; this method does not raise.
        """
        try:
            graph_data = self.fetch_graph_data()
            G = self._build_networkx_graph(graph_data)

            # Use an explicit figure/axes rather than pyplot's implicit
            # "current figure", and close exactly this figure in ``finally``
            # so it is released even when drawing or saving fails.
            fig, ax = plt.subplots(figsize=(16, 12))
            try:
                pos = nx.spring_layout(G, k=0.9, iterations=50)

                # add_edge() creates any endpoint that the node query did not
                # return, and such nodes carry no attributes -- hence .get().
                node_sizes = [
                    300 + len(str(attrs.get('properties', {}))) * 10
                    for _, attrs in G.nodes(data=True)
                ]
                # Colors simply alternate by position in the node list
                node_colors = [
                    'lightblue' if idx % 2 == 0 else 'lightgreen'
                    for idx in range(G.number_of_nodes())
                ]
                node_labels = {
                    node: attrs.get('label', str(node))
                    for node, attrs in G.nodes(data=True)
                }

                nx.draw_networkx_nodes(G, pos, ax=ax,
                                       node_color=node_colors,
                                       node_size=node_sizes,
                                       alpha=0.8)
                nx.draw_networkx_edges(G, pos, ax=ax,
                                       edge_color='gray',
                                       arrows=True,
                                       width=1.5)
                nx.draw_networkx_labels(G, pos, ax=ax,
                                        labels=node_labels,
                                        font_size=8)
                ax.set_title("Neo4j Graph Visualization")
                ax.axis('off')

                with io.BytesIO() as buffer:
                    fig.savefig(buffer, format='png', dpi=300, bbox_inches='tight')
                    image_png = buffer.getvalue()
            finally:
                plt.close(fig)

            graphic = base64.b64encode(image_png).decode('utf-8')
            return f"data:image/png;base64,{graphic}"
        except Exception as e:
            print(f"Error in graph visualization: {e}")
            return f"Error visualizing graph: {e}"

    def close(self):
        """Close the Neo4j driver connection"""
        self.driver.close()
def create_gradio_interface(uri, username, password):
    """
    Create a Gradio interface for Neo4j graph visualization

    Args:
        uri (str): Neo4j database URI
        username (str): Neo4j username
        password (str): Neo4j password

    Returns:
        tuple: ``(iface, visualizer)`` -- the ``gr.Interface`` and the
        ``Neo4jGraphVisualizer`` it draws from. The caller is responsible
        for calling ``visualizer.close()``.
    """
    visualizer = Neo4jGraphVisualizer(uri, username, password)
    data_uri_prefix = "data:image/png;base64,"

    def visualize_graph():
        """Render the graph and return the path of a PNG file for ``gr.Image``."""
        try:
            result = visualizer.visualize_graph()
        except Exception as e:
            raise gr.Error(f"Error: {str(e)}") from e

        # Neo4jGraphVisualizer.visualize_graph() reports failure by returning
        # a message string rather than raising. Surface that as a Gradio
        # error instead of handing the image component text it cannot show.
        if not isinstance(result, str) or not result.startswith(data_uri_prefix):
            raise gr.Error(str(result))

        # The image output expects a file path, not a data URI, so decode
        # the PNG and write it to a temporary file.
        # NOTE(review): the file is not deleted here because Gradio reads it
        # after this function returns; temp files accumulate per request.
        png_bytes = base64.b64decode(result[len(data_uri_prefix):])
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as image_file:
            image_file.write(png_bytes)
        return image_file.name

    # Create Gradio interface
    iface = gr.Interface(
        fn=visualize_graph,
        inputs=[],
        outputs=gr.Image(type="filepath"),
        title="Neo4j Graph Visualization",
        description="Visualize graph data from Neo4j database"
    )
    return iface, visualizer
# Configuration. Each value is read from the environment variable of the same
# name, so deployments can supply credentials without editing this file; the
# literals are fallbacks that preserve the previous hard-coded behavior.
# NOTE(review): the NEO4J_PASSWORD fallback is a credential committed to
# source. Rotate it and remove the literal once NEO4J_PASSWORD is set in the
# deployment environment.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://b96332bd.databases.neo4j.io")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get(
    "NEO4J_PASSWORD", "qviTdN6cw66AjIv6lu7kXcsN4keYPdXc2gAWuIoB8T4"
)
AURA_INSTANCEID = os.environ.get("AURA_INSTANCEID", "b96332bd")
AURA_INSTANCENAME = os.environ.get("AURA_INSTANCENAME", "Instance01")
def main():
    """Build the Gradio app, serve it on port 7860, and always close the driver.

    A launch failure is printed rather than propagated; the Neo4j driver is
    closed whether or not the launch succeeded.
    """
    app, neo4j_client = create_gradio_interface(
        NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD
    )
    launch_options = {'server_name': '0.0.0.0', 'server_port': 7860}

    try:
        app.launch(**launch_options)
    except Exception as launch_error:
        print(f"Error launching interface: {launch_error}")
    finally:
        # Release the Neo4j connection on every exit path
        neo4j_client.close()
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()