danidanidani commited on
Commit
31692d5
1 Parent(s): 64558b1

Update src/frontend/visualizations.py

Browse files
Files changed (1) hide show
  1. src/frontend/visualizations.py +150 -116
src/frontend/visualizations.py CHANGED
@@ -5,17 +5,23 @@ import matplotlib.pyplot as plt
5
  import numpy as np
6
  from streamlit_agraph import agraph, Node, Edge, Config
7
 
8
- def plot_compatibility(plants, compatibility_matrix, is_mini=False):
9
 
 
10
  # Create the graph
11
  G = nx.Graph()
12
  G.add_nodes_from(plants)
13
  for i in range(len(plants)):
14
  for j in range(i + 1, len(plants)):
15
  if compatibility_matrix[i][j] == 0:
16
- G.add_edge(plants[i], plants[j], color='dimgrey')
17
  else:
18
- G.add_edge(plants[i], plants[j], color='green' if compatibility_matrix[i][j] == 1 else 'mediumvioletred')
 
 
 
 
 
 
19
 
20
  # Generate positions for the nodes
21
  pos = nx.spring_layout(G)
@@ -25,31 +31,27 @@ def plot_compatibility(plants, compatibility_matrix, is_mini=False):
25
  x=[pos[node][0] for node in G.nodes()],
26
  y=[pos[node][1] for node in G.nodes()],
27
  text=list(G.nodes()),
28
- mode='markers+text',
29
- textposition='top center',
30
- hoverinfo='text',
31
  marker=dict(
32
  size=40,
33
- color='lightblue',
34
  line_width=2,
35
- )
36
  )
37
 
38
  # Create edge trace
39
  edge_trace = go.Scatter(
40
- x=[],
41
- y=[],
42
- line=dict(width=1, color='dimgrey'),
43
- hoverinfo='none',
44
- mode='lines'
45
  )
46
 
47
  # Add coordinates to edge trace
48
  for edge in G.edges():
49
  x0, y0 = pos[edge[0]]
50
  x1, y1 = pos[edge[1]]
51
- edge_trace['x'] += tuple([x0, x1, None])
52
- edge_trace['y'] += tuple([y0, y1, None])
53
 
54
  # Create edge traces for colored edges
55
  edge_traces = []
@@ -57,13 +59,13 @@ def plot_compatibility(plants, compatibility_matrix, is_mini=False):
57
  for edge in G.edges(data=True):
58
  x0, y0 = pos[edge[0]]
59
  x1, y1 = pos[edge[1]]
60
- color = edge[2]['color']
61
  trace = go.Scatter(
62
  x=[x0, x1],
63
  y=[y0, y1],
64
- mode='lines',
65
  line=dict(width=2, color=color),
66
- hoverinfo='none'
67
  )
68
  edge_traces.append(trace)
69
  edge_legend.add(color) # Add edge color to the set
@@ -71,57 +73,55 @@ def plot_compatibility(plants, compatibility_matrix, is_mini=False):
71
  # Create layout
72
  layout = go.Layout(
73
  showlegend=False,
74
- hovermode='closest',
75
  margin=dict(b=20, l=5, r=5, t=40),
76
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
77
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
78
  )
79
 
80
  # Create figure
81
  fig = go.Figure(data=[edge_trace, *edge_traces, node_trace], layout=layout)
82
 
83
-
84
  # Create custom legend for edge colors
85
  custom_legend = []
86
- legend_names = ['Neutral', 'Negative', 'Positive']
87
- legend_colors = ['dimgrey', 'mediumvioletred', 'green']
88
 
89
  for name, color in zip(legend_names, legend_colors):
90
  custom_legend.append(
91
  go.Scatter(
92
  x=[None],
93
  y=[None],
94
- mode='markers',
95
  marker=dict(color=color),
96
- name=f'{name}',
97
  showlegend=True,
98
- hoverinfo='none'
99
  )
100
  )
101
  if is_mini == False:
102
  # Create layout for custom legend figure
103
  legend_layout = go.Layout(
104
- title='Plant Compatibility Network Graph',
105
  showlegend=True,
106
  margin=dict(b=1, t=100),
107
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
108
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
109
  height=120,
110
  legend=dict(
111
- title='Edge Colors',
112
- orientation='h',
113
  x=-1,
114
  y=1.1,
115
- bgcolor='rgba(0,0,0,0)'
116
- )
117
  )
118
  else:
119
  fig.update_layout(
120
- autosize=False,
121
- width=300,
122
- height=300,)
123
-
124
-
125
 
126
  if is_mini == False:
127
  # Create figure for custom legend
@@ -129,21 +129,22 @@ def plot_compatibility(plants, compatibility_matrix, is_mini=False):
129
  # Render the custom legend using Plotly in Streamlit
130
  st.plotly_chart(legend_fig, use_container_width=True)
131
 
132
-
133
  # Render the graph using Plotly in Streamlit
134
  st.plotly_chart(fig)
135
 
136
 
137
-
138
  # this is not used as it needs to be refactored and is not working as intended
139
  def show_plant_tips():
140
  tips_string = st.session_state.plant_care_tips
141
 
142
  tips_list = tips_string.split("\n")
143
  num_tips = len(tips_list)
144
- st.markdown("## Plant Care Tips for your plants: " + str(st.session_state.input_plants_raw) + "\n\n" + st.session_state.plant_care_tips)
145
-
146
-
 
 
 
147
 
148
 
149
  def visualize_groupings_sankey():
@@ -169,33 +170,34 @@ def visualize_groupings_sankey():
169
  compatibility = compatibility_matrix[species1_index][species2_index]
170
 
171
  if compatibility == 1:
172
- color = 'green'
173
  elif compatibility == -1:
174
- color = 'pink'
175
  else:
176
- color = 'grey'
177
 
178
- links.append(dict(source=j, target=k, value=compatibility, color=color))
 
 
179
 
180
  # Create the Sankey diagram
181
- fig = go.Figure(data=[go.Sankey(
182
- node=dict(
183
- label=nodes,
184
- color="lightblue"
185
- ),
186
- link=dict(
187
- source=[link['source'] for link in links],
188
- target=[link['target'] for link in links],
189
- value=[link['value'] for link in links],
190
- color=[link['color'] for link in links]
191
- )
192
- )])
 
193
 
194
  # Set the layout properties
195
  layout = go.Layout(
196
- plot_bgcolor='black',
197
- paper_bgcolor='black',
198
- title_font=dict(color='white')
199
  )
200
 
201
  # Set the figure layout
@@ -209,7 +211,7 @@ def visualize_groupings():
209
  groupings = st.session_state.grouping
210
  compatibility_matrix = st.session_state.extracted_mat
211
  plant_list = st.session_state.input_plants_raw
212
-
213
  def generate_grouping_matrices(groupings, compatibility_matrix, plant_list):
214
  grouping_matrices = []
215
  for grouping in groupings:
@@ -217,17 +219,20 @@ def visualize_groupings():
217
  submatrix = [[compatibility_matrix[i][j] for j in indices] for i in indices]
218
  grouping_matrices.append(submatrix)
219
  return grouping_matrices
220
-
221
- grouping_matrices = generate_grouping_matrices(groupings, compatibility_matrix, plant_list)
 
 
222
  for i, submatrix in enumerate(grouping_matrices):
223
- col1, col2= st.columns([1,3])
224
- with col1:
225
  st.write(f"Plant Bed {i + 1}")
226
  st.write("Plant List")
227
  st.write(groupings[i])
228
  with col2:
229
- plot_compatibility_with_agraph(groupings[i], st.session_state.full_mat, is_mini=True)
230
-
 
231
 
232
 
233
  def plot_compatibility_with_agraph(plants, compatibility_matrix, is_mini=False):
@@ -245,18 +250,22 @@ def plot_compatibility_with_agraph(plants, compatibility_matrix, is_mini=False):
245
  size_n = 32 if not is_mini else 24
246
  # Create nodes with images
247
  for plant in plants:
248
- nodes.append(Node(id=plant,
249
- label=plant,
250
- # make text bigger
251
- font={'size': 20},
252
- # spread nodes out
253
- scaling={'label': {'enabled': True}},
254
- size=size_n,
255
- shape="circularImage",
256
- image=get_image_url(plant)))
 
 
 
 
257
 
258
  # Create edges based on compatibility
259
- #for i in range(len(st.session_state.plant_list)):
260
  # loop through all plants in raw long list and find the index of the plant in the plant list to get relevant metadata. skip if we are looking at the same plant
261
  for i, i_p in enumerate(st.session_state.plant_list):
262
  for j, j_p in enumerate(st.session_state.plant_list):
@@ -268,75 +277,103 @@ def plot_compatibility_with_agraph(plants, compatibility_matrix, is_mini=False):
268
  else:
269
  length_e = 150
270
 
271
- if i_p in st.session_state.input_plants_raw and j_p in st.session_state.input_plants_raw:
 
 
 
272
  # use the compatibility matrix and the plant to index mapping to determine the color of the edge
273
  if compatibility_matrix[i][j] == 1:
274
- color = 'green'
275
- edges.append(Edge(source=i_p, target=j_p,width = 3.5, type="CURVE_SMOOTH", color=color, length=length_e))
276
- print(i,j,i_p,j_p,color)
 
 
 
 
 
 
 
 
 
277
  elif compatibility_matrix[i][j] == -1:
278
- color = 'mediumvioletred'
279
- edges.append(Edge(source=i_p, target=j_p,width = 3.5, type="CURVE_SMOOTH", color=color, length=length_e))
280
- print(i,j,i_p,j_p,color)
 
 
 
 
 
 
 
 
 
281
 
282
  else:
283
- color = 'dimgrey'
284
- edges.append(Edge(source=i_p, target=j_p,width = .2, type="CURVE_SMOOTH", color=color, length=length_e))
285
- print(i,j,i_p,j_p,color)
 
 
 
 
 
 
 
 
 
286
 
287
-
288
-
289
  # Configuration for the graph
290
- config = Config(width=650 if not is_mini else 400,
291
- height=400 if not is_mini else 400,
292
- directed=False,
293
- physics=True,
294
- hierarchical=False,
295
- nodeHighlightBehavior=True,
296
- highlightColor="#F7A7A6",
297
- collapsible=True,
298
- maxZoom=5,
299
- minZoom=0.2,
300
- initialZoom=4,
301
- )
302
-
303
 
304
  # Handling for non-mini version
305
  if not is_mini:
306
  # Create custom legend for edge colors at the top of the page
307
  custom_legend = []
308
- legend_names = ['Neutral', 'Negative', 'Positive']
309
- legend_colors = ['dimgrey', 'mediumvioletred', 'green']
310
 
311
  for name, color in zip(legend_names, legend_colors):
312
  custom_legend.append(
313
  go.Scatter(
314
  x=[None],
315
  y=[None],
316
- mode='markers',
317
  marker=dict(color=color),
318
  name=name,
319
  showlegend=True,
320
- hoverinfo='none'
321
  )
322
  )
323
 
324
  # Create layout for custom legend figure
325
  legend_layout = go.Layout(
326
- title='Plant Compatibility Network Graph',
327
  showlegend=True,
328
  margin=dict(b=1, t=100),
329
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
330
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
331
  height=120,
332
  legend=dict(
333
- title='Edge Colors',
334
- orientation='h',
335
  # make it appear above the graph
336
  x=-1,
337
  y=1.1,
338
- bgcolor='rgba(0,0,0,0)'
339
- )
340
  )
341
 
342
  # Create figure for custom legend
@@ -345,10 +382,7 @@ def plot_compatibility_with_agraph(plants, compatibility_matrix, is_mini=False):
345
  # Render the custom legend using Plotly in Streamlit
346
  st.plotly_chart(legend_fig, use_container_width=True)
347
 
348
-
349
  # Render the graph using streamlit-agraph
350
- return_value = agraph(nodes=nodes,
351
- edges=edges,
352
- config=config)
353
-
354
 
 
5
  import numpy as np
6
  from streamlit_agraph import agraph, Node, Edge, Config
7
 
 
8
 
9
+ def plot_compatibility(plants, compatibility_matrix, is_mini=False):
10
  # Create the graph
11
  G = nx.Graph()
12
  G.add_nodes_from(plants)
13
  for i in range(len(plants)):
14
  for j in range(i + 1, len(plants)):
15
  if compatibility_matrix[i][j] == 0:
16
+ G.add_edge(plants[i], plants[j], color="dimgrey")
17
  else:
18
+ G.add_edge(
19
+ plants[i],
20
+ plants[j],
21
+ color="green"
22
+ if compatibility_matrix[i][j] == 1
23
+ else "mediumvioletred",
24
+ )
25
 
26
  # Generate positions for the nodes
27
  pos = nx.spring_layout(G)
 
31
  x=[pos[node][0] for node in G.nodes()],
32
  y=[pos[node][1] for node in G.nodes()],
33
  text=list(G.nodes()),
34
+ mode="markers+text",
35
+ textposition="top center",
36
+ hoverinfo="text",
37
  marker=dict(
38
  size=40,
39
+ color="lightblue",
40
  line_width=2,
41
+ ),
42
  )
43
 
44
  # Create edge trace
45
  edge_trace = go.Scatter(
46
+ x=[], y=[], line=dict(width=1, color="dimgrey"), hoverinfo="none", mode="lines"
 
 
 
 
47
  )
48
 
49
  # Add coordinates to edge trace
50
  for edge in G.edges():
51
  x0, y0 = pos[edge[0]]
52
  x1, y1 = pos[edge[1]]
53
+ edge_trace["x"] += tuple([x0, x1, None])
54
+ edge_trace["y"] += tuple([y0, y1, None])
55
 
56
  # Create edge traces for colored edges
57
  edge_traces = []
 
59
  for edge in G.edges(data=True):
60
  x0, y0 = pos[edge[0]]
61
  x1, y1 = pos[edge[1]]
62
+ color = edge[2]["color"]
63
  trace = go.Scatter(
64
  x=[x0, x1],
65
  y=[y0, y1],
66
+ mode="lines",
67
  line=dict(width=2, color=color),
68
+ hoverinfo="none",
69
  )
70
  edge_traces.append(trace)
71
  edge_legend.add(color) # Add edge color to the set
 
73
  # Create layout
74
  layout = go.Layout(
75
  showlegend=False,
76
+ hovermode="closest",
77
  margin=dict(b=20, l=5, r=5, t=40),
78
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
79
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
80
  )
81
 
82
  # Create figure
83
  fig = go.Figure(data=[edge_trace, *edge_traces, node_trace], layout=layout)
84
 
 
85
  # Create custom legend for edge colors
86
  custom_legend = []
87
+ legend_names = ["Neutral", "Negative", "Positive"]
88
+ legend_colors = ["dimgrey", "mediumvioletred", "green"]
89
 
90
  for name, color in zip(legend_names, legend_colors):
91
  custom_legend.append(
92
  go.Scatter(
93
  x=[None],
94
  y=[None],
95
+ mode="markers",
96
  marker=dict(color=color),
97
+ name=f"{name}",
98
  showlegend=True,
99
+ hoverinfo="none",
100
  )
101
  )
102
  if is_mini == False:
103
  # Create layout for custom legend figure
104
  legend_layout = go.Layout(
105
+ title="Plant Compatibility Network Graph",
106
  showlegend=True,
107
  margin=dict(b=1, t=100),
108
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
109
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
110
  height=120,
111
  legend=dict(
112
+ title="Edge Colors",
113
+ orientation="h",
114
  x=-1,
115
  y=1.1,
116
+ bgcolor="rgba(0,0,0,0)",
117
+ ),
118
  )
119
  else:
120
  fig.update_layout(
121
+ autosize=False,
122
+ width=300,
123
+ height=300,
124
+ )
 
125
 
126
  if is_mini == False:
127
  # Create figure for custom legend
 
129
  # Render the custom legend using Plotly in Streamlit
130
  st.plotly_chart(legend_fig, use_container_width=True)
131
 
 
132
  # Render the graph using Plotly in Streamlit
133
  st.plotly_chart(fig)
134
 
135
 
 
136
  # this is not used as it needs to be refactored and is not working as intended
137
  def show_plant_tips():
138
  tips_string = st.session_state.plant_care_tips
139
 
140
  tips_list = tips_string.split("\n")
141
  num_tips = len(tips_list)
142
+ st.markdown(
143
+ "## Plant Care Tips for your plants: "
144
+ + str(st.session_state.input_plants_raw)
145
+ + "\n\n"
146
+ + st.session_state.plant_care_tips
147
+ )
148
 
149
 
150
  def visualize_groupings_sankey():
 
170
  compatibility = compatibility_matrix[species1_index][species2_index]
171
 
172
  if compatibility == 1:
173
+ color = "green"
174
  elif compatibility == -1:
175
+ color = "pink"
176
  else:
177
+ color = "grey"
178
 
179
+ links.append(
180
+ dict(source=j, target=k, value=compatibility, color=color)
181
+ )
182
 
183
  # Create the Sankey diagram
184
+ fig = go.Figure(
185
+ data=[
186
+ go.Sankey(
187
+ node=dict(label=nodes, color="lightblue"),
188
+ link=dict(
189
+ source=[link["source"] for link in links],
190
+ target=[link["target"] for link in links],
191
+ value=[link["value"] for link in links],
192
+ color=[link["color"] for link in links],
193
+ ),
194
+ )
195
+ ]
196
+ )
197
 
198
  # Set the layout properties
199
  layout = go.Layout(
200
+ plot_bgcolor="black", paper_bgcolor="black", title_font=dict(color="white")
 
 
201
  )
202
 
203
  # Set the figure layout
 
211
  groupings = st.session_state.grouping
212
  compatibility_matrix = st.session_state.extracted_mat
213
  plant_list = st.session_state.input_plants_raw
214
+
215
  def generate_grouping_matrices(groupings, compatibility_matrix, plant_list):
216
  grouping_matrices = []
217
  for grouping in groupings:
 
219
  submatrix = [[compatibility_matrix[i][j] for j in indices] for i in indices]
220
  grouping_matrices.append(submatrix)
221
  return grouping_matrices
222
+
223
+ grouping_matrices = generate_grouping_matrices(
224
+ groupings, compatibility_matrix, plant_list
225
+ )
226
  for i, submatrix in enumerate(grouping_matrices):
227
+ col1, col2 = st.columns([1, 3])
228
+ with col1:
229
  st.write(f"Plant Bed {i + 1}")
230
  st.write("Plant List")
231
  st.write(groupings[i])
232
  with col2:
233
+ plot_compatibility_with_agraph(
234
+ groupings[i], st.session_state.full_mat, is_mini=True
235
+ )
236
 
237
 
238
  def plot_compatibility_with_agraph(plants, compatibility_matrix, is_mini=False):
 
250
  size_n = 32 if not is_mini else 24
251
  # Create nodes with images
252
  for plant in plants:
253
+ nodes.append(
254
+ Node(
255
+ id=plant,
256
+ label=plant,
257
+ # make text bigger
258
+ font={"size": 20},
259
+ # spread nodes out
260
+ scaling={"label": {"enabled": True}},
261
+ size=size_n,
262
+ shape="circularImage",
263
+ image=get_image_url(plant),
264
+ )
265
+ )
266
 
267
  # Create edges based on compatibility
268
+ # for i in range(len(st.session_state.plant_list)):
269
  # loop through all plants in raw long list and find the index of the plant in the plant list to get relevant metadata. skip if we are looking at the same plant
270
  for i, i_p in enumerate(st.session_state.plant_list):
271
  for j, j_p in enumerate(st.session_state.plant_list):
 
277
  else:
278
  length_e = 150
279
 
280
+ if (
281
+ i_p in st.session_state.input_plants_raw
282
+ and j_p in st.session_state.input_plants_raw
283
+ ):
284
  # use the compatibility matrix and the plant to index mapping to determine the color of the edge
285
  if compatibility_matrix[i][j] == 1:
286
+ color = "green"
287
+ edges.append(
288
+ Edge(
289
+ source=i_p,
290
+ target=j_p,
291
+ width=3.5,
292
+ type="CURVE_SMOOTH",
293
+ color=color,
294
+ length=length_e,
295
+ )
296
+ )
297
+ print(i, j, i_p, j_p, color)
298
  elif compatibility_matrix[i][j] == -1:
299
+ color = "mediumvioletred"
300
+ edges.append(
301
+ Edge(
302
+ source=i_p,
303
+ target=j_p,
304
+ width=3.5,
305
+ type="CURVE_SMOOTH",
306
+ color=color,
307
+ length=length_e,
308
+ )
309
+ )
310
+ print(i, j, i_p, j_p, color)
311
 
312
  else:
313
+ color = "dimgrey"
314
+ edges.append(
315
+ Edge(
316
+ source=i_p,
317
+ target=j_p,
318
+ width=0.2,
319
+ type="CURVE_SMOOTH",
320
+ color=color,
321
+ length=length_e,
322
+ )
323
+ )
324
+ print(i, j, i_p, j_p, color)
325
 
 
 
326
  # Configuration for the graph
327
+ config = Config(
328
+ width=650 if not is_mini else 400,
329
+ height=400 if not is_mini else 400,
330
+ directed=False,
331
+ physics=True,
332
+ hierarchical=False,
333
+ nodeHighlightBehavior=True,
334
+ highlightColor="#F7A7A6",
335
+ collapsible=True,
336
+ maxZoom=5,
337
+ minZoom=0.2,
338
+ initialZoom=4,
339
+ )
340
 
341
  # Handling for non-mini version
342
  if not is_mini:
343
  # Create custom legend for edge colors at the top of the page
344
  custom_legend = []
345
+ legend_names = ["Neutral", "Negative", "Positive"]
346
+ legend_colors = ["dimgrey", "mediumvioletred", "green"]
347
 
348
  for name, color in zip(legend_names, legend_colors):
349
  custom_legend.append(
350
  go.Scatter(
351
  x=[None],
352
  y=[None],
353
+ mode="markers",
354
  marker=dict(color=color),
355
  name=name,
356
  showlegend=True,
357
+ hoverinfo="none",
358
  )
359
  )
360
 
361
  # Create layout for custom legend figure
362
  legend_layout = go.Layout(
363
+ title="Plant Compatibility Network Graph",
364
  showlegend=True,
365
  margin=dict(b=1, t=100),
366
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
367
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
368
  height=120,
369
  legend=dict(
370
+ title="Edge Colors",
371
+ orientation="h",
372
  # make it appear above the graph
373
  x=-1,
374
  y=1.1,
375
+ bgcolor="rgba(0,0,0,0)",
376
+ ),
377
  )
378
 
379
  # Create figure for custom legend
 
382
  # Render the custom legend using Plotly in Streamlit
383
  st.plotly_chart(legend_fig, use_container_width=True)
384
 
 
385
  # Render the graph using streamlit-agraph
386
+ return_value = agraph(nodes=nodes, edges=edges, config=config)
387
+
 
 
388