# visualization / app.py
# Author: rockerritesh
# Last commit: "Update app.py" (d2b6669, verified)
import base64
import io
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from neo4j import GraphDatabase
class Neo4jGraphVisualizer:
    """Fetch the whole graph from Neo4j and render it as a PNG image."""

    def __init__(self, uri, username, password):
        """
        Initialize the Neo4j graph database connection.

        Args:
            uri (str): Neo4j database URI
            username (str): Neo4j username
            password (str): Neo4j password
        """
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def fetch_graph_data(self):
        """
        Fetch every node and every directed relationship from Neo4j.

        Returns:
            dict: ``{'nodes': [...], 'relationships': [...]}``.
            Each node is ``{'id', 'label', 'properties'}``: ``id`` is the
            node's ``elementId`` and ``label`` is the first entry of
            ``labels(n)`` (or ``'Unknown'`` when it has none).
            Each relationship is ``{'source', 'target', 'type', 'properties'}``,
            where ``source`` and ``target`` are ``elementId`` values.
        """
        with self.driver.session() as session:
            # Fetch all nodes with elementId instead of deprecated ID()
            nodes_result = session.run("""
                MATCH (n)
                RETURN elementId(n) as id,
                       labels(n) as labels,
                       properties(n) as properties
            """)
            # Consume the node result fully before running the next query on
            # the same session.
            nodes = [
                {
                    'id': record['id'],
                    'label': record['labels'][0] if record['labels'] else 'Unknown',
                    'properties': dict(record['properties']),
                }
                for record in nodes_result
            ]

            # Fetch all relationships using elementId
            relationships_result = session.run("""
                MATCH (a)-[r]->(b)
                RETURN
                    elementId(a) as source_id,
                    elementId(b) as target_id,
                    type(r) as relationship_type,
                    properties(r) as relationship_properties
            """)
            relationships = [
                {
                    'source': record['source_id'],
                    'target': record['target_id'],
                    'type': record['relationship_type'],
                    'properties': dict(record.get('relationship_properties', {})),
                }
                for record in relationships_result
            ]

        return {'nodes': nodes, 'relationships': relationships}

    def visualize_graph(self):
        """
        Visualize the graph using NetworkX and Matplotlib.

        Returns:
            str: ``"data:image/png;base64,<...>"`` on success. On any failure
            the error is printed and ``"Error visualizing graph: <error>"`` is
            returned instead; this method does not raise.
        """
        try:
            graph_data = self.fetch_graph_data()
            G = self._build_graph(graph_data)
            image_png = self._render_png(G)
        except Exception as e:
            print(f"Error in graph visualization: {e}")
            return f"Error visualizing graph: {e}"

        graphic = base64.b64encode(image_png).decode('utf-8')
        return f"data:image/png;base64,{graphic}"

    @staticmethod
    def _build_graph(graph_data):
        """Build a NetworkX DiGraph from the dict returned by ``fetch_graph_data``."""
        G = nx.DiGraph()
        for node in graph_data['nodes']:
            properties = node.get('properties', {})
            # Display the 'name' property when present, else the elementId.
            G.add_node(node['id'],
                       label=properties.get('name', str(node['id'])),
                       properties=node['properties'])
        for rel in graph_data['relationships']:
            G.add_edge(rel['source'], rel['target'],
                       type=rel['type'],
                       properties=rel['properties'])
        return G

    @staticmethod
    def _render_png(G, dpi=300):
        """Draw ``G`` on a standalone Matplotlib figure and return PNG bytes."""
        # A Figure created directly is not registered with pyplot's global
        # figure manager. It cannot leak if drawing fails, and concurrent
        # requests do not share pyplot's "current figure".
        from matplotlib.figure import Figure

        fig = Figure(figsize=(16, 12))
        ax = fig.add_subplot(111)
        pos = nx.spring_layout(G, k=0.9, iterations=50)

        # Nodes that appear only as relationship endpoints (not returned by the
        # node query) carry no attributes, so read them with defaults.
        node_list = list(G.nodes())
        node_sizes = [
            300 + len(str(G.nodes[node].get('properties', {}))) * 10
            for node in node_list
        ]
        node_colors = [
            'lightblue' if idx % 2 == 0 else 'lightgreen'
            for idx in range(len(node_list))
        ]
        nx.draw_networkx_nodes(G, pos, ax=ax,
                               node_color=node_colors,
                               node_size=node_sizes,
                               alpha=0.8)
        nx.draw_networkx_edges(G, pos, ax=ax,
                               edge_color='gray',
                               arrows=True,
                               width=1.5)
        nx.draw_networkx_labels(G, pos, ax=ax,
                                labels={node: G.nodes[node].get('label', str(node))
                                        for node in node_list},
                                font_size=8)
        ax.set_title("Neo4j Graph Visualization")
        ax.axis('off')

        with io.BytesIO() as buffer:
            fig.savefig(buffer, format='png', dpi=dpi, bbox_inches='tight')
            return buffer.getvalue()

    def close(self):
        """Close the Neo4j driver connection."""
        self.driver.close()
def create_gradio_interface(uri, username, password):
    """
    Create a Gradio interface for Neo4j graph visualization.

    Args:
        uri (str): Neo4j database URI
        username (str): Neo4j username
        password (str): Neo4j password

    Returns:
        tuple: ``(iface, visualizer)``, i.e. the ``gr.Interface`` and the
        ``Neo4jGraphVisualizer`` backing it. The caller must call
        ``visualizer.close()`` when done.
    """
    visualizer = Neo4jGraphVisualizer(uri, username, password)
    png_data_uri_prefix = "data:image/png;base64,"

    def visualize_graph():
        """Render the graph and return the path of a PNG file for gr.Image."""
        try:
            graph_image = visualizer.visualize_graph()
        except Exception as e:
            raise gr.Error(f"Error: {str(e)}") from e

        # Neo4jGraphVisualizer.visualize_graph reports failures by *returning*
        # an error message. Show that message in the UI instead of handing it
        # to gr.Image as though it were an image path.
        if not isinstance(graph_image, str) or not graph_image.startswith(png_data_uri_prefix):
            raise gr.Error(str(graph_image))

        # The output component is gr.Image(type="filepath"). Give it a real PNG
        # file rather than a data-URI string.
        png_bytes = base64.b64decode(graph_image[len(png_data_uri_prefix):])
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(png_bytes)
        return tmp.name

    # Create Gradio interface
    iface = gr.Interface(
        fn=visualize_graph,
        inputs=[],
        outputs=gr.Image(type="filepath"),
        title="Neo4j Graph Visualization",
        description="Visualize graph data from Neo4j database",
    )
    return iface, visualizer
# Configuration is read from environment variables (e.g. deployment secrets) so
# credentials are never committed to source control.
# SECURITY: a password was previously hard-coded in this file. Treat it as
# leaked and rotate it in the Neo4j console.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://b96332bd.databases.neo4j.io")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
# Deliberately no real default. If NEO4J_PASSWORD is unset, connecting is
# expected to fail authentication, and the failure is reported in the UI.
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "")
AURA_INSTANCEID = os.environ.get("AURA_INSTANCEID", "b96332bd")
AURA_INSTANCENAME = os.environ.get("AURA_INSTANCENAME", "Instance01")
def main():
    """
    Build the Gradio app from the module-level Neo4j settings and serve it.

    The app listens on 0.0.0.0:7860. Launch errors are printed rather than
    re-raised. The Neo4j driver is closed once ``launch()`` returns or raises.
    """
    app, graph_visualizer = create_gradio_interface(
        NEO4J_URI,
        NEO4J_USERNAME,
        NEO4J_PASSWORD,
    )
    launch_kwargs = dict(server_name='0.0.0.0', server_port=7860)
    try:
        app.launch(**launch_kwargs)
    except Exception as launch_error:
        print(f"Error launching interface: {launch_error}")
    finally:
        # Always release the driver, whether launch succeeded or failed.
        graph_visualizer.close()


if __name__ == "__main__":
    main()