|
import streamlit as st |
|
import networkx as nx |
|
import plotly.graph_objects as go |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from streamlit_agraph import agraph, Node, Edge, Config |
|
|
|
def plot_compatibility(plants, compatibility_matrix, is_mini=False):
    """Draw a Plotly network of pairwise plant compatibility in Streamlit.

    Every unordered pair of plants is joined by an edge whose colour comes
    from ``compatibility_matrix[i][j]`` (only the upper triangle, ``i < j``,
    is read):

    * ``0``          -> ``'dimgrey'``          (neutral)
    * ``1``          -> ``'green'``            (positive)
    * anything else  -> ``'mediumvioletred'``  (negative)

    Args:
        plants: Sequence of plant names; position ``i`` corresponds to
            row/column ``i`` of ``compatibility_matrix``.
        compatibility_matrix: Matrix indexable as ``m[i][j]``.
        is_mini: If truthy, render a fixed 300x300 graph without a legend.
            Otherwise, first render a separate title/legend figure and then
            the graph.
    """
    G = nx.Graph()
    G.add_nodes_from(plants)
    for i, plant_a in enumerate(plants):
        for j in range(i + 1, len(plants)):
            score = compatibility_matrix[i][j]
            if score == 0:
                color = 'dimgrey'
            elif score == 1:
                color = 'green'
            else:
                color = 'mediumvioletred'
            G.add_edge(plant_a, plants[j], color=color)

    # No seed is passed, so the spring layout (and node placement) can
    # differ between reruns.
    pos = nx.spring_layout(G)

    node_trace = go.Scatter(
        x=[pos[node][0] for node in G.nodes()],
        y=[pos[node][1] for node in G.nodes()],
        text=list(G.nodes()),
        mode='markers+text',
        textposition='top center',
        hoverinfo='text',
        marker=dict(
            size=40,
            color='lightblue',
            line_width=2,
        ),
    )

    # One trace per edge so each edge can carry its own colour.
    edge_traces = []
    for source, target, data in G.edges(data=True):
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_traces.append(
            go.Scatter(
                x=[x0, x1],
                y=[y0, y1],
                mode='lines',
                line=dict(width=2, color=data['color']),
                hoverinfo='none',
            )
        )

    layout = go.Layout(
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    )

    # Edges first so the node markers are drawn on top of them.
    fig = go.Figure(data=[*edge_traces, node_trace], layout=layout)

    if is_mini:
        fig.update_layout(autosize=False, width=300, height=300)
    else:
        # The legend lives in its own small figure: invisible (None) points
        # whose only purpose is to produce one legend entry per edge colour.
        legend_traces = [
            go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                marker=dict(color=color),
                name=name,
                showlegend=True,
                hoverinfo='none',
            )
            for name, color in (
                ('Neutral', 'dimgrey'),
                ('Negative', 'mediumvioletred'),
                ('Positive', 'green'),
            )
        ]
        legend_layout = go.Layout(
            title='Plant Compatibility Network Graph',
            showlegend=True,
            margin=dict(b=1, t=100),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            height=120,
            legend=dict(
                title='Edge Colors',
                orientation='h',
                x=-1,
                y=1.1,
                bgcolor='rgba(0,0,0,0)',
            ),
        )
        legend_fig = go.Figure(data=legend_traces, layout=legend_layout)
        st.plotly_chart(legend_fig, use_container_width=True)

    st.plotly_chart(fig)
|
|
|
|
|
|
|
|
|
def show_plant_tips():
    """Render the stored plant-care tips as Markdown under a heading.

    Reads ``st.session_state.plant_care_tips``, which is concatenated into the
    Markdown, so it must be a string. Also reads
    ``st.session_state.input_plants_raw``, which is shown in the heading via
    ``str()``.
    """
    tips = st.session_state.plant_care_tips
    st.markdown(
        "## Plant Care Tips for your plants: "
        + str(st.session_state.input_plants_raw)
        + "\n\n"
        + tips
    )
|
|
|
|
|
|
|
|
|
def visualize_groupings_sankey():
    """Draw one Sankey diagram per plant bed showing pairwise compatibility.

    Reads from ``st.session_state``:

    * ``grouping``: the beds; each bed is an iterable of plant names.
    * ``extracted_mat``: compatibility scores, indexed by a plant's position
      in ``input_plants_raw``.
    * ``input_plants_raw``: the plant list used for that indexing. A species
      missing from it raises ``ValueError``.

    Every unordered pair of species in a bed becomes one link. Its colour
    encodes the score: ``1`` -> green, ``-1`` -> pink, anything else -> grey.
    """
    groupings = st.session_state.grouping
    compatibility_matrix = st.session_state.extracted_mat
    plant_list = st.session_state.input_plants_raw

    for bed_number, bed_species in enumerate(groupings, start=1):
        st.subheader(f"Plant Bed {bed_number}")

        nodes = list(bed_species)

        sources, targets, values, colors = [], [], [], []
        for j, species1 in enumerate(nodes):
            for k in range(j + 1, len(nodes)):
                species2 = nodes[k]
                score = compatibility_matrix[plant_list.index(species1)][plant_list.index(species2)]

                if score == 1:
                    color = 'green'
                elif score == -1:
                    color = 'pink'
                else:
                    color = 'grey'

                sources.append(j)
                targets.append(k)
                # A Sankey link's value is its width and must be positive.
                # Using the raw score made neutral (0) links invisible and
                # negative (-1) links invalid. Use a uniform width and let
                # the colour carry the compatibility instead.
                values.append(1)
                colors.append(color)

        fig = go.Figure(data=[go.Sankey(
            node=dict(
                label=nodes,
                color="lightblue",
            ),
            link=dict(
                source=sources,
                target=targets,
                value=values,
                color=colors,
            ),
        )])

        fig.update_layout(go.Layout(
            plot_bgcolor='black',
            paper_bgcolor='black',
            title_font=dict(color='white'),
        ))

        st.plotly_chart(fig)
|
|
|
|
|
def visualize_groupings():
    """Show each plant bed's species list next to a mini compatibility graph.

    Reads ``st.session_state.grouping`` (an iterable of beds, each a list of
    plant names) and ``st.session_state.full_mat``. For every bed, a
    two-column row is written: the bed number and plant list on the left,
    and on the right the output of ``plot_compatibility_with_agraph`` called
    with ``is_mini=True``.
    """
    full_matrix = st.session_state.full_mat

    for bed_number, bed_plants in enumerate(st.session_state.grouping, start=1):
        col1, col2 = st.columns([1, 3])
        with col1:
            st.write(f"Plant Bed {bed_number}")
            st.write("Plant List")
            st.write(bed_plants)
        with col2:
            plot_compatibility_with_agraph(bed_plants, full_matrix, is_mini=True)
|
|
|
|
|
|
|
def plot_compatibility_with_agraph(plants, compatibility_matrix, is_mini=False):
    """Render an interactive plant compatibility graph with streamlit-agraph.

    Each plant in ``plants`` becomes a circular-image node. Edges are added
    for every ordered pair of distinct plants that are both in ``plants`` and
    in ``st.session_state.input_plants_raw``. Colour and width come from
    ``compatibility_matrix``:

    * ``1``          -> green, width 3.5
    * ``-1``         -> mediumvioletred, width 3.5
    * anything else  -> dimgrey, width 0.2

    Args:
        plants: Plant names to draw. Each must appear in
            ``st.session_state.plant_list``; its index there selects the node
            image and the matrix row/column.
        compatibility_matrix: Indexable as ``m[i][j]``, where ``i`` and ``j``
            are positions in ``st.session_state.plant_list``.
        is_mini: If truthy, use smaller nodes, shorter edges and a narrower
            canvas, and skip the legend figure.
    """
    plant_list = st.session_state.plant_list
    input_plants = st.session_state.input_plants_raw

    def get_image_url(plant_name):
        index = plant_list.index(plant_name)
        return f"https://github.com/4dh/GRDN/blob/dev/src/assets/plant_images/plant_{index}.png?raw=true"

    size_n = 24 if is_mini else 32
    length_e = 150 if is_mini else 300

    nodes = [
        Node(id=plant,
             label=plant,
             font={'size': 20},
             scaling={'label': {'enabled': True}},
             size=size_n,
             shape="circularImage",
             image=get_image_url(plant))
        for plant in plants
    ]

    # Only plants that are drawn as nodes may be edge endpoints. Otherwise,
    # a per-bed (mini) graph would receive edges to plants from other beds
    # that have no node. Membership in input_plants_raw is still required,
    # as before.
    edge_endpoints = {plant for plant in plants if plant in input_plants}

    edges = []
    for i, i_p in enumerate(plant_list):
        if i_p not in edge_endpoints:
            continue
        for j, j_p in enumerate(plant_list):
            if i == j or j_p not in edge_endpoints:
                continue
            # NOTE(review): both (i, j) and (j, i) produce an edge even though
            # the graph is undirected. Kept as-is in case the matrix is not
            # symmetric.
            score = compatibility_matrix[i][j]
            if score == 1:
                color, width = 'green', 3.5
            elif score == -1:
                color, width = 'mediumvioletred', 3.5
            else:
                color, width = 'dimgrey', .2
            edges.append(Edge(source=i_p, target=j_p, width=width, type="CURVE_SMOOTH", color=color, length=length_e))

    config = Config(width=400 if is_mini else 650,
                    height=400,
                    directed=False,
                    physics=True,
                    hierarchical=False,
                    nodeHighlightBehavior=True,
                    highlightColor="#F7A7A6",
                    collapsible=True,
                    maxZoom=5,
                    minZoom=0.2,
                    initialZoom=4,
                    )

    if not is_mini:
        # Separate Plotly figure used only as a colour legend: invisible
        # (None) points, one legend entry per edge colour.
        custom_legend = [
            go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                marker=dict(color=color),
                name=name,
                showlegend=True,
                hoverinfo='none',
            )
            for name, color in (
                ('Neutral', 'dimgrey'),
                ('Negative', 'mediumvioletred'),
                ('Positive', 'green'),
            )
        ]
        legend_layout = go.Layout(
            title='Plant Compatibility Network Graph',
            showlegend=True,
            margin=dict(b=1, t=100),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            height=120,
            legend=dict(
                title='Edge Colors',
                orientation='h',
                x=-1,
                y=1.1,
                bgcolor='rgba(0,0,0,0)',
            ),
        )
        legend_fig = go.Figure(data=custom_legend, layout=legend_layout)
        st.plotly_chart(legend_fig, use_container_width=True)

    agraph(nodes=nodes, edges=edges, config=config)
|
|
|
|