sna_g34 / app /analyze.py
yashpulse's picture
Final
7861c5d
raw
history blame contribute delete
No virus
14 kB
import streamlit as st
import streamlit_antd_components as sac
import pandas as pd
from streamlit_monaco import st_monaco
import os
from pygwalker.api.streamlit import StreamlitRenderer
import networkx as nx
from streamlit_agraph import agraph, Node, Edge, Config
import numpy as np
def get_pyg_renderer(df, spec) -> StreamlitRenderer:
    """Build a pygwalker ``StreamlitRenderer`` for ``df``.

    The renderer is pointed at ``./<spec>_config.json`` in read/write mode.
    Any file already present at that path is deleted before the renderer is
    constructed.

    Args:
        df: Data handed to ``StreamlitRenderer`` unchanged.
        spec: Name used as the prefix of the config file.

    Returns:
        The newly constructed ``StreamlitRenderer``.
    """
    config_path = f"./{spec}_config.json"
    # Remove whatever an earlier call left behind at this path.
    if os.path.exists(config_path):
        os.remove(config_path)
    return StreamlitRenderer(df, spec=config_path, spec_io_mode="rw")
def get_nodes(df: pd.DataFrame, source, target):
    """Build one graph ``Node`` per distinct endpoint found in ``df``.

    Rows whose ``source`` or ``target`` cell is missing (NaN/None) are
    ignored. For the remaining rows the source value and then the target
    value are turned into string ids; each id yields a single red node of
    size 10, in first-seen order.

    The caller's DataFrame is not modified.

    Args:
        df: Edge list with at least the ``source`` and ``target`` columns.
        source: Name of the column holding edge start points.
        target: Name of the column holding edge end points.

    Returns:
        A list of ``Node`` objects with unique ids.
    """
    # dropna returns a new frame, so the input stays untouched.
    complete_rows = df.dropna(subset=[source, target])

    nodes = []
    seen_ids = set()  # O(1) membership instead of scanning a list per row
    for source_value, target_value in zip(complete_rows[source], complete_rows[target]):
        for raw_value in (source_value, target_value):
            # Drops ".0" so float-typed ids such as 12.0 read as "12".
            # NOTE(review): str.replace removes every ".0" occurrence, not
            # just a trailing one; kept as-is so ids stay identical to the
            # ones get_edges produces.
            node_id = str(raw_value).replace(".0", "")
            if node_id in seen_ids:
                continue
            seen_ids.add(node_id)
            nodes.append(Node(
                id=node_id,
                label=node_id,
                size=10,
                color="red",
            ))
    return nodes
def get_edges(df: pd.DataFrame, source, target):
    """Build one graph ``Edge`` per row of ``df`` that has both endpoints.

    Rows whose ``source`` or ``target`` cell is missing (NaN/None) are
    skipped, so no edge is emitted towards a literal ``"None"`` node.
    Endpoint ids are formatted with the same ``str(value).replace(".0", "")``
    rule used by ``get_nodes``. The edge label is taken from the
    ``"relation"`` column; a missing relation becomes ``None``.

    The caller's DataFrame is not modified.

    Args:
        df: Edge list containing ``source``, ``target`` and ``"relation"``
            columns.
        source: Name of the column holding edge start points.
        target: Name of the column holding edge end points.

    Returns:
        A list of ``Edge`` objects of size 1, in row order.
    """
    edges = []
    for source_value, target_value, relation in zip(df[source], df[target], df["relation"]):
        # Check the raw cells: after str() a missing value is the truthy
        # text "None"/"nan" and can no longer be detected.
        if pd.isna(source_value) or pd.isna(target_value):
            continue
        edges.append(Edge(
            source=str(source_value).replace(".0", ""),
            target=str(target_value).replace(".0", ""),
            label=None if pd.isna(relation) else relation,
            size=1,
        ))
    return edges
def _split_schema(helper_json):
    """Split helper metadata into ``(tables, relationships)`` lookups.

    Entries flagged ``isTable`` map ``object_name`` to their ``columns``;
    other entries flagged ``isRelationship`` map ``object_name`` to a dict
    with their ``source`` and ``target``.
    """
    tables = {}
    relationships = {}
    for key in helper_json:
        entry = helper_json[key]
        if entry["isTable"]:
            tables[entry["object_name"]] = entry["columns"]
        elif entry["isRelationship"]:
            relationships[entry["object_name"]] = {
                "source": entry["source"],
                "target": entry["target"],
            }
    return tables, relationships


def _first_int64_column(columns):
    """Return the first column whose declared type is ``"INT64"``, else ``""``."""
    for column in columns:
        if columns[column] == "INT64":
            return column
    return ""


def _build_visual_query(source_node, target_node, relationship,
                        source_primary_key, target_primary_key,
                        source_node_columns, target_node_columns, limit):
    """Assemble the MATCH/RETURN/LIMIT query text from the visual selections."""
    relationship_str = "|".join(f":{rel.strip()}" for rel in relationship)
    extra_columns = "".join(
        [f"\n,(nodes(q)[1]).{column} as source_{column}" for column in source_node_columns]
        + [f"\n,(nodes(q)[2]).{column} as target_{column}" for column in target_node_columns]
    )
    return (
        f"MATCH\nq = (s:{source_node})-[r{relationship_str}]->(t:{target_node})\n"
        f"RETURN (nodes(q)[1]).{source_primary_key} as source,\n"
        "(rels(q)[1])._label as relation,\n"
        f"(nodes(q)[2]).{target_primary_key} as target\n"
        f"{extra_columns}"
        f"\n LIMIT {limit};"
    )


def _rows_for_node(frame, node_id):
    """Return rows whose ``source`` id equals ``node_id``, else rows whose ``target`` does.

    Ids are compared after the same ``".0"`` stripping that ``get_nodes``
    applies, so a clicked node matches only its own rows. ``frame`` is not
    modified. An empty frame is returned when nothing matches.
    """
    wanted = str(node_id)
    for column in ("source", "target"):
        # .str.replace works on substrings; Series.replace would only swap
        # cells that are exactly ".0" and leave "12.0" untouched.
        ids = frame[column].astype(str).str.replace(".0", "", regex=False)
        matches = frame[ids == wanted]
        if not matches.empty:
            break
    return matches


def analyze():
    """Render the "Analyze" page.

    The page lets the user pick or name a saved query, choose a loaded file,
    and then either type a query ("Raw" mode) or assemble one from source
    table, target table, relationships, extra columns and a row limit
    ("Visual" mode). "Save" stores the query text through
    ``st.session_state.sna``; "Execute" runs it and stores the outcome in
    ``st.session_state["results"]``. When results are present they are shown
    in a pygwalker explorer and in a "Network" section offering either a
    pygwalker view of the networkx analysis or an interactive graph with the
    ERGM summary and the rows of the selected node.

    Nothing below the query-name controls is rendered when no helper
    metadata or no loaded file is available.
    """
    st.title('Analyze')
    header_cols = st.columns([4, 1])
    with header_cols[0]:
        st.subheader("Analyze loaded data and visualize it.")
    with header_cols[1]:
        query_mode = st.radio("Query Mode", ["Visual", "Raw"], horizontal=True, index=0, key="query_type")

    query_controls = st.columns([1, 1, 1, 1, 1])
    with query_controls[0]:
        saved_queries = st.session_state.sna.get_queries()
        if saved_queries is None:
            saved_queries = {}
        selected_query = st.selectbox("Select Query", saved_queries.keys(), index=None, key="selected_query")
    with query_controls[1]:
        if selected_query is None:
            query_name = st.text_input(key="query_name", value="New Query", label="Query Name")
            if query_name not in st.session_state:
                # First time this name is seen: seed it with a default query,
                # persist it and rerun so it shows up among the saved queries.
                st.session_state[query_name] = "MATCH (n) RETURN n LIMIT 25"
                st.session_state.sna.save_query(query_name, st.session_state[query_name])
                st.rerun()
        else:
            query_name = st.text_input(key="query_name", value=selected_query, label="Query Name")
            st.session_state[query_name] = saved_queries[selected_query]

    helpers = st.session_state.sna.get_helper()
    if helpers is None:
        return
    files = [helper.split("/")[-1] for helper in helpers if not helper.endswith(".json")]
    select_files_cols = st.columns([1, 1, 1, 1, 1])
    with select_files_cols[0]:
        selected_files = st.selectbox("Select Loaded File", files)
    if selected_files is None:
        return

    # Start from empty lookups so a file without a matching helper entry
    # yields empty selectors instead of an unbound-name error further down.
    tables, relationships = {}, {}
    for helper in helpers:
        if helper.endswith(selected_files):
            tables, relationships = _split_schema(helpers[helper])

    if query_mode == "Raw":
        query_cols = st.columns([4, 6])
        with query_cols[0]:
            with st.expander("Objects", expanded=True):
                st.write("Tables")
                st.write(tables)
                st.write("Relationships")
                st.write(relationships)
        with query_cols[1]:
            with st.expander("Query", expanded=True):
                initial_text = "" if selected_query is None else saved_queries[selected_query]
                query_text = st_monaco(value=initial_text, height="200px", language="cypher", theme="vs-dark", lineNumbers=True)
    elif query_mode == "Visual":
        # Falls back to the stored text until a relationship is selected.
        query_text = st.session_state[query_name]
        table_names = list(tables)
        relationship = []
        with st.expander("Node and Edges Selection", expanded=True):
            st.write("Select the source and target nodes for the network")
            node_cols = st.columns([1, 1, 1, 1, 1])
            with node_cols[0]:
                source_node = st.selectbox("Source Node", table_names, index=0, key="source_node")
            with node_cols[1]:
                # Default to the second table, but stay in range when the
                # file describes fewer than two tables.
                target_index = 1 if len(table_names) > 1 else 0
                target_node = st.selectbox("Target Node", table_names, index=target_index, key="target_node")
            if source_node is not None and target_node is not None:
                st.write("Select the relationship between the source and target nodes")
                relationship_cols = st.columns([1, 1, 1, 1, 1])
                with relationship_cols[0]:
                    matching_relationships = [
                        name
                        for name, ends in relationships.items()
                        if ends["source"] == source_node and ends["target"] == target_node
                    ]
                    relationship = st.multiselect("Relationships", matching_relationships, key="relationship")
        # A non-empty selection implies both nodes were chosen above.
        if relationship:
            # NOTE(review): when a table has no INT64 column the key is ""
            # and the generated RETURN clause is incomplete.
            source_primary_key = _first_int64_column(tables[source_node])
            target_primary_key = _first_int64_column(tables[target_node])
            with st.expander("Returned Columns"):
                source_columns = [column for column in tables[source_node] if column != source_primary_key]
                target_columns = [column for column in tables[target_node] if column != target_primary_key]
                returned_columns = st.columns([3, 3, 1])
                with returned_columns[0]:
                    source_node_columns = st.multiselect("Source Node Columns", source_columns, key="source_node_columns")
                with returned_columns[1]:
                    target_node_columns = st.multiselect("Target Node Columns", target_columns, key="target_node_columns")
                with returned_columns[-1]:
                    limit = st.number_input("Limit", min_value=1, max_value=10000, value=200, key="limit")
            query_text = _build_visual_query(
                source_node, target_node, relationship,
                source_primary_key, target_primary_key,
                source_node_columns, target_node_columns, limit,
            )

    query_actions = st.columns([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
    with query_actions[0]:
        if st.button("Save", key="save_query"):
            st.session_state["results"] = None
            st.session_state[query_name] = query_text
            st.session_state.sna.save_query(query_name, st.session_state[query_name])
    with query_actions[1]:
        if st.button("Execute", type="primary", key="execute_query"):
            # Clear first so stale results are not shown if execution fails.
            st.session_state["results"] = None
            st.session_state["results"] = st.session_state.sna.execute_query(query_text)

    if "results" not in st.session_state:
        return
    results = st.session_state["results"]
    with st.expander('Results', expanded=False):
        # "Save" resets results to None; there is nothing to explore then.
        if results is not None:
            result_renderer = get_pyg_renderer(results, "result")
            result_renderer.explorer()
    with st.expander('Network', expanded=False):
        if results is not None:
            network_columns = st.columns([1, 1, 1, 1])
            networkx_result = st.session_state.sna.apply_networkx_analysis(results, "source", "target")
            ergm_result = st.session_state.sna.apply_ergm(results, "source", "target")
            with network_columns[-1]:
                analysis_type = st.radio("Analysis Type", ["Graph", "Visualization"], index=0, key="network_type", horizontal=True)
            if analysis_type == "Visualization":
                networkx_renderer = get_pyg_renderer(networkx_result, "networkx")
                networkx_renderer.explorer()
            elif analysis_type == "Graph":
                config = Config(width=1000, height=750, directed=True, physics=True, hierarchical=True)
                nodes = get_nodes(networkx_result, "source", "target")
                edges = get_edges(networkx_result, "source", "target")
                graph_columns = st.columns([9, 3])
                with graph_columns[0]:
                    # presumably the id of the node selected in the graph,
                    # or None when nothing is selected
                    return_value = agraph(nodes=nodes, edges=edges, config=config)
                with graph_columns[1]:
                    st.write("Ergm Analysis")
                    # Filter into a new dict: no KeyError when the entry is
                    # absent and the ERGM result itself is left intact.
                    ergm_summary = {
                        name: value
                        for name, value in ergm_result["ergm"].items()
                        if name != "degree_distribution"
                    }
                    ergm_result_df = pd.DataFrame.from_dict(ergm_summary, orient="index", columns=["Value"])
                    st.dataframe(ergm_result_df)
                    # NOTE(review): placed in the side column next to the
                    # ERGM table; confirm against the intended layout.
                    if return_value is not None:
                        st.write("Networkx Analysis")
                        with open("./output.csv", "w") as f:
                            f.write(networkx_result.to_csv(index=False))
                        search = _rows_for_node(networkx_result, return_value)
                        if not search.empty:
                            st.write(search)