Spaces:
Running
on
Zero
Running
on
Zero
import plotly.graph_objects as go | |
import networkx as nx | |
import plotly.graph_objects as go | |
import networkx as nx | |
def create_cytoscape_plot(entities, relationships, *, seed=None):
    """Build an interactive Plotly figure of a knowledge graph.

    Despite its name, this renders with Plotly (not Cytoscape). Nodes are
    placed with a NetworkX spring layout and drawn as one marker trace. Edges
    are drawn as line segments, and each edge's relation(s) appear as a text
    label at the segment midpoint.

    Args:
        entities: Mapping of entity id -> attribute dict. Every item of the
            attribute dict is stored on the graph node. The ``"value"`` and
            ``"type"`` keys are used to build the node label.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Endpoints that do not appear in ``entities`` are still drawn,
            labelled with their id. Several distinct relations between the
            same ordered pair are shown together, comma-separated.
        seed: Optional random seed forwarded to ``nx.spring_layout`` so the
            layout is reproducible. ``None`` (the default) keeps the layout
            non-deterministic.

    Returns:
        A ``plotly.graph_objects.Figure`` containing an edge-line trace, a
        node trace (colored by total in+out degree), and an edge-label trace.
    """
    G = nx.DiGraph()  # directed: each relationship has a source and a target
    for entity_id, entity_data in entities.items():
        G.add_node(entity_id, **entity_data)
    for source, relation, target in relationships:
        if G.has_edge(source, target):
            # A DiGraph holds a single edge per ordered pair. Keep every
            # distinct relation instead of silently overwriting the old one.
            relations = G.edges[source, target]["relations"]
            if relation not in relations:
                relations.append(relation)
        else:
            G.add_edge(source, target, relations=[relation])

    pos = nx.spring_layout(G, k=0.5, iterations=50, seed=seed)

    # Collect coordinates in plain lists and build each trace once. Appending
    # to trace properties with `+=` re-validates the whole array every time.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for source, target, data in G.edges(data=True):
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None breaks the polyline so each edge is its own segment.
        edge_x += (x0, x1, None)
        edge_y += (y0, y1, None)
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(", ".join(map(str, data["relations"])))

    node_x, node_y, node_text, node_color = [], [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        entity = entities.get(node)
        if entity is None:
            # Node was created implicitly by a relationship endpoint.
            node_text.append(str(node))
        else:
            node_text.append(f"{entity['value']} ({entity['type']})")
        # Total in+out degree. G.neighbors() on a DiGraph yields successors only.
        node_color.append(G.degree(node))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(width=1, color="#888"),
        hoverinfo="none",
    )
    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_text,
        textposition="top center",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_color,
            size=15,
            colorbar=dict(
                thickness=15,
                # Nested title form. The flat `titleside` was removed in Plotly 6.
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )
    # All edge labels share one text-only trace rather than one trace per edge.
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        textfont=dict(size=8),
        hoverinfo="none",
        showlegend=False,
    )

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    return go.Figure(
        data=[edge_trace, node_trace, label_trace],
        layout=go.Layout(
            # Nested title form. The flat `titlefont` was removed in Plotly 6.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=hidden_axis,
            # Lock a 1:1 aspect ratio so the layout is not stretched. Anchor
            # only one axis; mutual anchors form a cycle plotly.js rejects.
            yaxis=dict(hidden_axis, scaleanchor="x", scaleratio=1),
            # Styling for shapes drawn with Plotly's drawing tools, if they
            # are enabled in the modebar config.
            newshape=dict(line_color="#009900"),
            width=800,
            height=600,
        ),
    )