Spaces:
Build error
Build error
import streamlit as st | |
import streamlit_antd_components as sac | |
import pandas as pd | |
from streamlit_monaco import st_monaco | |
import os | |
from pygwalker.api.streamlit import StreamlitRenderer | |
import networkx as nx | |
from streamlit_agraph import agraph, Node, Edge, Config | |
import numpy as np | |
def get_pyg_renderer(df, spec) -> StreamlitRenderer:
    """Create a PyGWalker Streamlit renderer over *df* backed by a fresh spec file.

    Args:
        df: Data to explore.
        spec: Short name (e.g. "result", "networkx") used to derive the spec
            path ``./{spec}_config.json``.

    Returns:
        A ``StreamlitRenderer`` opened on that path with ``spec_io_mode="rw"``.

    Any existing file at the spec path is deleted first, presumably so the
    explorer does not reuse charts saved by an earlier run.
    """
    config_path = f"./{spec}_config.json"
    # EAFP instead of exists()+remove(): another Streamlit session may delete the
    # same file between the check and the removal, and a missing file is exactly
    # the state we want.
    # NOTE(review): the path is shared by every session of the app; concurrent
    # users can still overwrite each other's spec file.
    try:
        os.remove(config_path)
    except FileNotFoundError:
        pass
    return StreamlitRenderer(df, spec=config_path, spec_io_mode="rw")
def get_nodes(df: pd.DataFrame, source, target):
    """Build one agraph ``Node`` per distinct endpoint of the edge list *df*.

    Args:
        df: Edge-list frame. It is only read, never modified.
        source: Name of the column holding each edge's source id.
        target: Name of the column holding each edge's target id.

    Returns:
        list of ``Node`` in first-seen order (a row's source before its target).
        Each id and label is ``str(value).replace(".0", "")``, the same
        normalization ``get_edges`` applies, so nodes and edges share ids.
        Rows with a null endpoint, and endpoints whose id normalizes to "",
        produce no node.
    """

    def _node_id(value):
        # Undo float promotion of integer ids (1.0 -> "1").
        # NOTE(review): this is a plain substring replace ("10.05" -> "105");
        # it is kept identical to get_edges so node and edge ids stay in sync.
        return str(value).replace(".0", "")

    # dropna treats both NaN and None as missing; returns a new frame, so the
    # caller's data is left untouched.
    endpoints = df.dropna(subset=[source, target])

    seen_ids = set()
    nodes = []
    for source_value, target_value in zip(endpoints[source], endpoints[target]):
        for value in (source_value, target_value):
            node_id = _node_id(value)
            # Skip empty ids and ids already emitted.
            if not node_id or node_id in seen_ids:
                continue
            seen_ids.add(node_id)
            nodes.append(Node(
                id=node_id,
                label=node_id,
                size=10,
                color="red",
            ))
    return nodes
def get_edges(df: pd.DataFrame, source, target):
    """Build one agraph ``Edge`` per row of the edge list *df*.

    Args:
        df: Edge-list frame with *source*, *target* and "relation" columns.
            It is only read, never modified.
        source: Name of the column holding each edge's source id.
        target: Name of the column holding each edge's target id.

    Returns:
        list of ``Edge``. Endpoint ids use ``str(value).replace(".0", "")``,
        the same normalization as ``get_nodes``. The label is the row's
        "relation" value, or None when that value is null.
        Rows with a null endpoint, or one whose id normalizes to "", are
        skipped, since ``get_nodes`` creates no node for them.
    """

    def _node_id(value):
        # Undo float promotion of integer ids (1.0 -> "1"); kept identical to
        # get_nodes so edge endpoints match node ids.
        return str(value).replace(".0", "")

    edges = []
    for source_value, target_value, relation in zip(df[source], df[target], df["relation"]):
        if pd.isna(source_value) or pd.isna(target_value):
            continue
        source_id = _node_id(source_value)
        target_id = _node_id(target_value)
        if not source_id or not target_id:
            continue
        edges.append(Edge(
            source=source_id,
            target=target_id,
            # A NaN label would not serialize as valid JSON; use None instead.
            label=None if pd.isna(relation) else relation,
            size=1,
        ))
    return edges
def analyze():
    """Render the "Analyze" page.

    Flow:
    1. Choose a saved query or name a new one.
    2. Pick a loaded file; its schema (tables / relationships) drives the
       query builder.
    3. Compose the query visually or as raw Cypher.
    4. Save or Execute it.
    5. Explore the results as a PyGWalker table and as a network graph with
       ERGM statistics.

    Relies on ``st.session_state.sna`` (set up elsewhere) for query storage,
    schema helpers, query execution and network analysis.
    """
    st.title('Analyze')
    header_columns = st.columns([4, 1])
    with header_columns[0]:
        st.subheader("Analyze loaded data and visualize it.")
    with header_columns[1]:
        query_mode = st.radio("Query Mode", ["Visual", "Raw"], horizontal=True, index=0, key="query_type")

    sna = st.session_state.sna
    query_controls = st.columns([1, 1, 1, 1, 1])
    with query_controls[0]:
        saved_queries = sna.get_queries()
        if saved_queries is None:
            saved_queries = {}
        selected_query = st.selectbox("Select Query", saved_queries.keys(), index=None, key="selected_query")
    with query_controls[1]:
        if selected_query is None:
            query_name = st.text_input(key="query_name", value="New Query", label="Query Name")
            if query_name not in st.session_state:
                # First use of this name: seed a starter query, persist it, then rerun.
                st.session_state[query_name] = "MATCH (n) RETURN n LIMIT 25"
                sna.save_query(query_name, st.session_state[query_name])
                st.rerun()
        else:
            query_name = st.text_input(key="query_name", value=selected_query, label="Query Name")
            st.session_state[query_name] = saved_queries[selected_query]

    helpers = sna.get_helper()
    if helpers is None:
        return
    files = [helper.split("/")[-1] for helper in helpers if not helper.endswith(".json")]
    select_files_cols = st.columns([1, 1, 1, 1, 1])
    with select_files_cols[0]:
        selected_files = st.selectbox("Select Loaded File", files)
    if selected_files is None:
        return

    # Start empty so both names are bound even if no helper matches.
    tables, relationships = {}, {}
    for helper in helpers:
        # Exact file-name match; a suffix test would also pick "mydata.csv" for "data.csv".
        if not helper.endswith(".json") and helper.split("/")[-1] == selected_files:
            tables, relationships = _parse_helper_schema(helpers[helper])

    if query_mode == "Raw":
        initial_query = "" if selected_query is None else saved_queries[selected_query]
        query_text = _render_raw_editor(tables, relationships, initial_query)
    else:
        query_text = _render_visual_builder(tables, relationships, st.session_state[query_name])

    query_actions = st.columns([1] * 11)
    with query_actions[0]:
        if st.button("Save", key="save_query"):
            st.session_state["results"] = None
            st.session_state[query_name] = query_text
            sna.save_query(query_name, st.session_state[query_name])
    with query_actions[1]:
        if st.button("Execute", type="primary", key="execute_query"):
            st.session_state["results"] = None
            st.session_state["results"] = sna.execute_query(query_text)

    if "results" in st.session_state:
        results = st.session_state["results"]
        # "results" is reset to None by Save (and before Execute), so guard both
        # expanders instead of handing None to the renderers.
        with st.expander('Results', expanded=False):
            if results is not None:
                get_pyg_renderer(results, "result").explorer()
        with st.expander('Network', expanded=False):
            if results is not None:
                _render_network(results)


def _parse_helper_schema(helper_json):
    """Split a file's helper description into ``(tables, relationships)``.

    tables maps each table's "object_name" to its "columns" entry.
    relationships maps each relationship's "object_name" to
    ``{"source": ..., "target": ...}``.
    """
    tables = {}
    relationships = {}
    for key in helper_json:
        entry = helper_json[key]
        if entry["isTable"]:
            tables[entry["object_name"]] = entry["columns"]
        elif entry["isRelationship"]:
            relationships[entry["object_name"]] = {
                "source": entry["source"],
                "target": entry["target"],
            }
    return tables, relationships


def _first_int64_column(columns):
    """Return the first column typed "INT64" (treated as the key), or "" if none."""
    for name in columns:
        if columns[name] == "INT64":
            return name
    return ""


def _build_visual_query(source_node, target_node, relationship_names,
                        source_key, target_key,
                        source_columns, target_columns, limit):
    """Assemble the Cypher query produced by the Visual builder.

    The MATCH binds path ``q``. ``nodes(q)[1]`` and ``nodes(q)[2]`` are
    returned as source and target, and ``rels(q)[1]`` as the relation.
    Extra columns are aliased ``source_<col>`` / ``target_<col>``.
    """
    relationship_pattern = "|".join(f":{name.strip()}" for name in relationship_names)
    parts = [
        "MATCH",
        f"q = (s:{source_node})-[r{relationship_pattern}]->(t:{target_node})",
        f"RETURN (nodes(q)[1]).{source_key} as source,",
        "(rels(q)[1])._label as relation,",
        f"(nodes(q)[2]).{target_key} as target",
    ]
    parts += [f",(nodes(q)[1]).{column} as source_{column}" for column in source_columns]
    parts += [f",(nodes(q)[2]).{column} as target_{column}" for column in target_columns]
    parts.append(f"LIMIT {limit};")
    return "\n".join(parts)


def _render_raw_editor(tables, relationships, initial_query):
    """Show the schema beside a Monaco Cypher editor; return what st_monaco returns."""
    query_cols = st.columns([4, 6])
    with query_cols[0]:
        with st.expander("Objects", expanded=True):
            st.write("Tables")
            st.write(tables)
            st.write("Relationships")
            st.write(relationships)
    with query_cols[1]:
        with st.expander("Query", expanded=True):
            query_text = st_monaco(value=initial_query, height="200px", language="cypher", theme="vs-dark", lineNumbers=True)
    return query_text


def _render_visual_builder(tables, relationships, default_query):
    """Render the table/relationship/column pickers and return the resulting query.

    Returns *default_query* unchanged until a source table, a target table and
    at least one relationship have been selected.
    """
    table_names = list(tables.keys())
    selected_relationships = []
    with st.expander("Node and Edges Selection", expanded=True):
        st.write("Select the source and target nodes for the network")
        selection_columns = st.columns([1, 1, 1, 1, 1])
        with selection_columns[0]:
            source_node = st.selectbox("Source Node", table_names, index=0, key="source_node")
        with selection_columns[1]:
            # Default to the second table, but fall back to the first when there is
            # only one; st.selectbox rejects an index outside the options.
            target_node = st.selectbox(
                "Target Node",
                table_names,
                index=1 if len(table_names) > 1 else 0,
                key="target_node",
            )
        if source_node is not None and target_node is not None:
            st.write("Select the relationship between the source and target nodes")
            selection_columns = st.columns([1, 1, 1, 1, 1])
            with selection_columns[0]:
                candidates = [
                    name
                    for name, ends in relationships.items()
                    if ends["source"] == source_node and ends["target"] == target_node
                ]
                selected_relationships = st.multiselect("Relationships", candidates, key="relationship")

    if source_node is None or target_node is None or not selected_relationships:
        return default_query

    source_key = _first_int64_column(tables[source_node])
    target_key = _first_int64_column(tables[target_node])
    # NOTE(review): when a table has no INT64 column its key is "" and the
    # generated RETURN clause is malformed; confirm every node table declares one.

    # Kept outside the selection expander: Streamlit forbids nested expanders.
    with st.expander("Returned Columns"):
        source_columns = [column for column in tables[source_node] if column != source_key]
        target_columns = [column for column in tables[target_node] if column != target_key]
        returned_columns = st.columns([3, 3, 1])
        with returned_columns[0]:
            source_node_columns = st.multiselect("Source Node Columns", source_columns, key="source_node_columns")
        with returned_columns[1]:
            target_node_columns = st.multiselect("Target Node Columns", target_columns, key="target_node_columns")
        with returned_columns[-1]:
            limit = st.number_input("Limit", min_value=1, max_value=10000, value=200, key="limit")

    return _build_visual_query(
        source_node, target_node, selected_relationships,
        source_key, target_key,
        source_node_columns, target_node_columns, limit,
    )


def _render_network(results):
    """Render the 'Network' expander body for a non-None query *results* frame."""
    sna = st.session_state.sna
    network_columns = st.columns([1, 1, 1, 1])
    # Assumes *results* has "source"/"target" columns, as the Visual builder's
    # query returns; raw queries must alias their columns the same way.
    networkx_result = sna.apply_networkx_analysis(results, "source", "target")
    with network_columns[-1]:
        analysis_type = st.radio("Analysis Type", ["Graph", "Visualization"], index=0, key="network_type", horizontal=True)
    if analysis_type == "Visualization":
        get_pyg_renderer(networkx_result, "networkx").explorer()
        return

    # ERGM statistics are only displayed in Graph mode, so compute them only here.
    ergm_result = sna.apply_ergm(results, "source", "target")
    config = Config(width=1000, height=750, directed=True, physics=True, hierarchical=True)
    nodes = get_nodes(networkx_result, "source", "target")
    edges = get_edges(networkx_result, "source", "target")
    graph_columns = st.columns([9, 3])
    with graph_columns[0]:
        selected_node = agraph(nodes=nodes, edges=edges, config=config)
    with graph_columns[1]:
        st.write("Ergm Analysis")
        # Build a filtered copy. Deleting the key in place would mutate apply_ergm's
        # return value and raise KeyError when the key is absent.
        # "degree_distribution" is presumably a sequence that does not fit the
        # one-value-per-row table.
        ergm_summary = {
            name: value
            for name, value in ergm_result["ergm"].items()
            if name != "degree_distribution"
        }
        st.dataframe(pd.DataFrame.from_dict(ergm_summary, orient="index", columns=["Value"]))
    if selected_node is not None:
        _show_selected_node(networkx_result, selected_node)


def _show_selected_node(networkx_result, node_id):
    """Show the analysis rows whose source (or, failing that, target) is *node_id*."""
    st.write("Networkx Analysis")
    # NOTE(review): dumps the full analysis to ./output.csv on every click, shared
    # by all sessions; looks like a debugging aid, so confirm it is still needed.
    with open("./output.csv", "w") as f:
        f.write(networkx_result.to_csv(index=False))

    def _ids(column):
        # Same str(value).replace(".0", "") normalization used for graph node ids.
        # Works on a derived Series, so networkx_result itself is not modified.
        return networkx_result[column].astype(str).str.replace(".0", "", regex=False)

    # Exact match rather than str.contains(), which treats the id as a regex and
    # matches substrings (node "1" would also match "10", "21", ...).
    matches = networkx_result[_ids("source") == node_id]
    if matches.empty:
        matches = networkx_result[_ids("target") == node_id]
    if not matches.empty:
        st.write(matches)