import streamlit as st
import streamlit_antd_components as sac
import pandas as pd
from streamlit_monaco import st_monaco
import os
from pygwalker.api.streamlit import StreamlitRenderer
import networkx as nx
from streamlit_agraph import agraph, Node, Edge, Config
import numpy as np


def get_pyg_renderer(df: pd.DataFrame, spec: str) -> StreamlitRenderer:
    """Return a PyGWalker renderer for *df* backed by ``./{spec}_config.json``.

    Any existing spec file is deleted first, so a previously saved chart spec
    is not reused; the renderer then reads/writes that file
    (``spec_io_mode="rw"``).
    """
    spec_path = f"./{spec}_config.json"
    if os.path.exists(spec_path):
        os.remove(spec_path)
    return StreamlitRenderer(df, spec=spec_path, spec_io_mode="rw")


def _clean_id(value) -> str:
    """Return *value* as a node-id string with one trailing ``".0"`` removed.

    Pandas upcasts integer id columns that contain nulls to float, so an id
    ``7`` can arrive as ``7.0``. Only the trailing suffix is stripped; a plain
    ``str.replace(".0", "")`` would also corrupt ids such as ``"10.05"``
    (-> ``"105"``).
    """
    text = str(value)
    return text[:-2] if text.endswith(".0") else text


def get_nodes(df: pd.DataFrame, source, target):
    """Return one red agraph ``Node`` per distinct id in the *source*/*target* columns.

    Rows with a null in either endpoint column are ignored (the same rows
    ``get_edges`` skips) and empty ids are skipped. Nodes are emitted in
    first-seen order, source before target within a row. *df* is not modified.
    """
    # dropna returns a new frame; NaN is not replaced in place because that
    # would mutate the caller's frame (and on pandas < 2, replace(nan, None)
    # forward-fills instead of inserting None).
    endpoints = df.dropna(subset=[source, target])
    nodes = []
    seen = set()  # O(1) membership test instead of scanning a list per row
    for src, tgt in zip(endpoints[source], endpoints[target]):
        for raw in (src, tgt):
            node_id = _clean_id(raw)
            if not node_id or node_id in seen:
                continue
            seen.add(node_id)
            nodes.append(Node(
                id=node_id,
                label=node_id,
                size=10,
                color="red",
            ))
    return nodes


def get_edges(df: pd.DataFrame, source, target):
    """Return one agraph ``Edge`` per row of *df*, from *source* to *target*.

    The edge label comes from the ``relation`` column (``None`` when the
    column is absent or the value is null). Rows with a null or empty endpoint
    are skipped, so every edge refers to a node that ``get_nodes`` creates.
    *df* is not modified.
    """
    endpoints = df.dropna(subset=[source, target])
    if "relation" in endpoints.columns:
        relations = endpoints["relation"]
    else:
        relations = [None] * len(endpoints)
    edges = []
    for src, tgt, relation in zip(endpoints[source], endpoints[target], relations):
        source_id = _clean_id(src)
        target_id = _clean_id(tgt)
        if not source_id or not target_id:
            continue
        edges.append(Edge(
            source=source_id,
            target=target_id,
            label=None if pd.isna(relation) else relation,
            size=1,
        ))
    return edges


def _split_schema(helper_json):
    """Split a loaded-file helper mapping into ``(tables, relationships)``.

    ``tables`` maps each table's ``object_name`` to its ``columns`` entry;
    ``relationships`` maps each relationship's ``object_name`` to a dict with
    its ``source`` and ``target`` table names.
    """
    tables = {}
    relationships = {}
    for entry in helper_json.values():
        if entry["isTable"]:
            tables[entry["object_name"]] = entry["columns"]
        elif entry["isRelationship"]:
            relationships[entry["object_name"]] = {
                "source": entry["source"],
                "target": entry["target"],
            }
    return tables, relationships


def _first_int64_column(columns) -> str:
    """Return the first column whose declared type is ``"INT64"``, else ``""``.

    *columns* maps column name -> type name. The result is used as the node's
    key column in generated queries (presumably the id column — TODO confirm).
    """
    return next((name for name, dtype in columns.items() if dtype == "INT64"), "")


def _build_visual_query(tables, relationships):
    """Render the visual query builder and return the generated Cypher query.

    Returns ``None`` until a source table, a target table and at least one
    relationship between them have been selected.
    """
    table_names = list(tables.keys())
    with st.expander("Node and Edges Selection", expanded=True):
        st.write("Select the source and target nodes for the network")
        node_edge_selection_columns = st.columns([1, 1, 1, 1, 1])
        with node_edge_selection_columns[0]:
            source_node = st.selectbox("Source Node", table_names, index=0, key="source_node")
        with node_edge_selection_columns[1]:
            # index=1 raises when the file has a single table; fall back to it.
            target_node = st.selectbox(
                "Target Node",
                table_names,
                index=1 if len(table_names) > 1 else 0,
                key="target_node",
            )
        if source_node is None or target_node is None:
            return None
        st.write("Select the relationship between the source and target nodes")
        node_edge_selection_columns = st.columns([1, 1, 1, 1, 1])
        with node_edge_selection_columns[0]:
            filtered_relationship = [
                name
                for name, ends in relationships.items()
                if ends["source"] == source_node and ends["target"] == target_node
            ]
            relationship = st.multiselect("Relationships", filtered_relationship, key="relationship")
    if not relationship:
        return None

    relationship_str = "|".join(f":{rel.strip()}" for rel in relationship)
    node_edge = f"MATCH q = (s:{source_node})-[r{relationship_str}]->(t:{target_node}) "

    # NOTE(review): if a table has no INT64 column the key is "" and the
    # generated RETURN clause is invalid; consider surfacing a warning.
    source_primary_key = _first_int64_column(tables[source_node])
    target_primary_key = _first_int64_column(tables[target_node])

    with st.expander("Returned Columns"):
        source_columns = [c for c in tables[source_node] if c != source_primary_key]
        target_columns = [c for c in tables[target_node] if c != target_primary_key]
        returned_columns = st.columns([3, 3, 1])
        with returned_columns[0]:
            source_node_columns = st.multiselect("Source Node Columns", source_columns, key="source_node_columns")
        with returned_columns[1]:
            target_node_columns = st.multiselect("Target Node Columns", target_columns, key="target_node_columns")
        with returned_columns[-1]:
            limit = st.number_input("Limit", min_value=1, max_value=10000, value=200, key="limit")

    extra_returns = "".join(
        [f"\n,(nodes(q)[1]).{column} as source_{column}" for column in source_node_columns]
        + [f"\n,(nodes(q)[2]).{column} as target_{column}" for column in target_node_columns]
    )
    return_statement = (
        f"RETURN (nodes(q)[1]).{source_primary_key} as source, "
        f"(rels(q)[1])._label as relation, "
        f"(nodes(q)[2]).{target_primary_key} as target {extra_returns}"
    )
    return node_edge + return_statement + f"\n LIMIT {limit};"


def _render_graph(networkx_result, ergm_result):
    """Draw the agraph network with ERGM stats; on node click, list that node's rows."""
    config = Config(width=1000, height=750, directed=True, physics=True, hierarchical=True)
    nodes = get_nodes(networkx_result, "source", "target")
    edges = get_edges(networkx_result, "source", "target")
    graph_columns = st.columns([9, 3])
    with graph_columns[0]:
        # Presumably the id of the clicked node, or None when nothing is selected.
        return_value = agraph(nodes=nodes, edges=edges, config=config)
    with graph_columns[1]:
        st.write("Ergm Analysis")
        # Build a filtered copy: don't mutate ergm_result, and don't fail if
        # "degree_distribution" is absent.
        ergm_result_to_display = {
            key: value
            for key, value in ergm_result["ergm"].items()
            if key != "degree_distribution"
        }
        ergm_result_df = pd.DataFrame.from_dict(ergm_result_to_display, orient="index", columns=["Value"])
        st.dataframe(ergm_result_df)

    if return_value is None:
        return
    st.write("Networkx Analysis")
    networkx_result.to_csv("./output.csv", index=False)
    # Normalise ids exactly as get_nodes/get_edges do so they match the
    # clicked node id; assign() returns a copy, leaving networkx_result intact.
    selected_id = _clean_id(return_value)
    networkx_result_search = networkx_result.assign(
        source=networkx_result["source"].map(_clean_id),
        target=networkx_result["target"].map(_clean_id),
    )
    # Exact match (not a regex substring search) on either endpoint.
    search = networkx_result_search[
        (networkx_result_search["source"] == selected_id)
        | (networkx_result_search["target"] == selected_id)
    ]
    if not search.empty:
        st.write(search)


def _render_results(results):
    """Show query results, the networkx/ERGM analysis and the graph views."""
    with st.expander('Results', expanded=False):
        result_renderer = get_pyg_renderer(results, "result")
        result_renderer.explorer()
    with st.expander('Network', expanded=False):
        network_columns = st.columns([1, 1, 1, 1])
        networkx_result = st.session_state.sna.apply_networkx_analysis(results, "source", "target")
        ergm_result = st.session_state.sna.apply_ergm(results, "source", "target")
        with network_columns[-1]:
            analysis_type = st.radio(
                "Analysis Type",
                ["Graph", "Visualization"],
                index=0,
                key="network_type",
                horizontal=True,
            )
        if analysis_type == "Visualization":
            networkx_renderer = get_pyg_renderer(networkx_result, "networkx")
            networkx_renderer.explorer()
        elif analysis_type == "Graph":
            _render_graph(networkx_result, ergm_result)


def analyze():
    """Render the Analyze page: choose/save a query, execute it, explore results.

    Relies on ``st.session_state.sna`` (configured elsewhere) for saved
    queries, loaded-file schemas, query execution and network analysis.
    Results are kept in ``st.session_state["results"]`` across reruns.
    """
    st.title('Analyze')
    col = st.columns([4, 1])
    with col[0]:
        st.subheader("Analyze loaded data and visualize it.")
    with col[1]:
        query_mode = st.radio("Query Mode", ["Visual", "Raw"], horizontal=True, index=0, key="query_type")

    query_controls = st.columns([1, 1, 1, 1, 1])
    with query_controls[0]:
        saved_queries = st.session_state.sna.get_queries()
        if saved_queries is None:
            saved_queries = {}
        selected_query = st.selectbox("Select Query", list(saved_queries.keys()), index=None, key="selected_query")
    with query_controls[1]:
        if selected_query is None:
            query_name = st.text_input(key="query_name", value="New Query", label="Query Name")
            if query_name not in st.session_state:
                # Seed and persist a default query for a brand-new name.
                st.session_state[query_name] = "MATCH (n) RETURN n LIMIT 25"
                st.session_state.sna.save_query(query_name, st.session_state[query_name])
                st.rerun()
        else:
            query_name = st.text_input(key="query_name", value=selected_query, label="Query Name")
            st.session_state[query_name] = saved_queries[selected_query]

    helpers = st.session_state.sna.get_helper()
    if helpers is not None:
        # Map displayed file name -> helper key; .json entries are not offered.
        file_keys = {}
        for helper in helpers:
            if not helper.endswith(".json"):
                file_keys[helper.split("/")[-1]] = helper
        select_files_cols = st.columns([1, 1, 1, 1, 1])
        with select_files_cols[0]:
            selected_files = st.selectbox("Select Loaded File", list(file_keys))
        if selected_files is not None:
            # Exact name lookup: endswith() would let "a.csv" match "data.csv".
            tables, relationships = _split_schema(helpers[file_keys[selected_files]])

            if query_mode == "Raw":
                query_cols = st.columns([4, 6])
                with query_cols[0]:
                    with st.expander("Objects", expanded=True):
                        st.write("Tables")
                        st.write(tables)
                        st.write("Relationships")
                        st.write(relationships)
                with query_cols[1]:
                    with st.expander("Query", expanded=True):
                        initial_query = "" if selected_query is None else saved_queries[selected_query]
                        query_text = st_monaco(
                            value=initial_query,
                            height="200px",
                            language="cypher",
                            theme="vs-dark",
                            lineNumbers=True,
                        )
            else:
                # Visual mode: fall back to the stored query until the builder
                # has enough selections to generate one.
                query_text = st.session_state[query_name]
                prepared_query = _build_visual_query(tables, relationships)
                if prepared_query is not None:
                    query_text = prepared_query

            query_actions = st.columns([1] * 11)
            with query_actions[0]:
                if st.button("Save", key="save_query"):
                    st.session_state["results"] = None
                    st.session_state[query_name] = query_text
                    st.session_state.sna.save_query(query_name, st.session_state[query_name])
            with query_actions[1]:
                if st.button("Execute", type="primary", key="execute_query"):
                    # Clear first so a failing query does not leave stale results.
                    st.session_state["results"] = None
                    st.session_state["results"] = st.session_state.sna.execute_query(query_text)

    # "results" is reset to None by Save/Execute; only render real results.
    results = st.session_state.get("results")
    if results is not None:
        _render_results(results)