rjadr committed
Commit
d8cce8d
1 Parent(s): 92d6900

Update app.py

Files changed (1)
  app.py +340 -3
app.py CHANGED
@@ -13,6 +13,11 @@ from pandas.api.types import (
)
import subprocess
from tempfile import NamedTemporaryFile
+ from itertools import combinations
+ import networkx as nx
+ import plotly.graph_objects as go
+ import colorcet as cc
+ from matplotlib.colors import rgb2hex

st.set_page_config(layout="wide")

 
@@ -38,6 +43,11 @@ def load_dataset():
@st.cache_data(show_spinner=False)
def load_dataframe(_dataset):
    dataframe = _dataset.remove_columns(['txt_embs', 'img_embs']).to_pandas()
+     # Extract hashtags with regex and convert to a set per post
+     dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
+     dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)
+     # Optionally drop specific hashtags (e.g. 'ditaduranuncamais') from the per-post hashtag sets
+     # dataframe['Hashtags'] = dataframe['Hashtags'].apply(lambda x: [item for item in x if not item.startswith('ditaduranuncamais')])
    # dataframe['Post Created'] = dataframe['Post Created'].dt.tz_convert('UTC')
    dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
    return dataframe
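
A quick sketch of what the new extraction step yields, using an invented two-row frame (the column names follow the app's dataframe; the post texts are made up):

```python
import pandas as pd

# Invented sample posts with the two text columns the app combines
df = pd.DataFrame({
    'Description': ['Memoria e verdade #DitaduraNuncaMais #Memoria', 'Ato no centro #memoria'],
    'Image Text': ['#verdade', ''],
})

# Same steps as in load_dataframe above
df['Hashtags'] = df.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
df['Hashtags'] = df['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)

print(df['Hashtags'].tolist())
# e.g. [{'ditaduranuncamais', 'memoria', 'verdade'}, {'memoria'}] (set order may vary)
```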
 
@@ -226,6 +236,231 @@ def image_to_image(image, k=5):
    scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
    return postprocess_results(scores, samples)

+ def disparity_filter(g: nx.Graph, weight: str = 'weight', alpha: float = 0.05) -> nx.Graph:
+     """
+     Computes the backbone of the input graph using the disparity filter algorithm.
+
+     The algorithm is proposed in:
+     M. A. Serrano, M. Boguna, and A. Vespignani,
+     "Extracting the Multiscale Backbone of Complex Weighted Networks",
+     PNAS, 106(16), pp 6483--6488 (2009).
+     DOI: 10.1073/pnas.0808904106
+
+     Implementation taken from https://groups.google.com/g/networkx-discuss/c/bCuHZ3qQ2po/m/QvUUJqOYDbIJ
+
+     Parameters
+     ----------
+     g : NetworkX graph
+         The input graph.
+     weight : str, optional (default='weight')
+         The name of the edge attribute to use as weight.
+     alpha : float, optional (default=0.05)
+         The statistical significance level for the disparity filter (p-value).
+
+     Returns
+     -------
+     backbone_graph : NetworkX graph
+         The backbone graph.
+     """
+     # Create an empty graph for the backbone
+     backbone_graph = nx.Graph()
+
+     # Iterate over all nodes in the input graph
+     for node in g:
+         # Get the degree of the node (number of edges connected to the node)
+         k_n = len(g[node])
+
+         # Only proceed if the node has more than one connection
+         if k_n > 1:
+             # Calculate the sum of weights of edges connected to the node
+             sum_w = sum(g[node][neighbor][weight] for neighbor in g[node])
+
+             # Iterate over all neighbors of the node
+             for neighbor in g[node]:
+                 # Get the weight of the edge between the node and its neighbor
+                 edge_weight = g[node][neighbor][weight]
+
+                 # Calculate the proportion of the total weight that this edge represents
+                 pij = float(edge_weight) / sum_w
+
+                 # Perform the disparity filter test. If it passes, the edge is considered significant and is added to the backbone
+                 if (1 - pij) ** (k_n - 1) < alpha:
+                     backbone_graph.add_edge(node, neighbor, weight=edge_weight)
+
+     # Return the backbone graph
+     return backbone_graph
+
+ st.cache_data(show_spinner=True)
+ def assign_community_colors(G: nx.Graph, attr: str = 'community') -> dict:
+     """
+     Assigns a unique color to each community in the input graph.
+
+     Parameters
+     ----------
+     G : nx.Graph
+         The input graph.
+     attr : str, optional
+         The node attribute of the community names or indexes (default is 'community').
+
+     Returns
+     -------
+     dict
+         A dictionary mapping each community to a unique color.
+     """
+     glasbey_colors = cc.glasbey_hv
+     communities_ = set(nx.get_node_attributes(G, attr).values())
+     return {community: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, community in enumerate(communities_)}
+
+ st.cache_data(show_spinner=True)
+ def generate_hover_text(G: nx.Graph, attr: str = 'community') -> list:
+     """
+     Generates hover text for each node in the input graph.
+
+     Parameters
+     ----------
+     G : nx.Graph
+         The input graph.
+     attr : str, optional
+         The node attribute of the community names or indexes (default is 'community').
+
+     Returns
+     -------
+     list
+         A list of strings containing the hover text for each node.
+     """
+     return [f"Node: {str(node)}<br>Community: {G.nodes[node][attr] + 1}<br># of connections: {len(adjacencies)}" for node, adjacencies in G.adjacency()]
+
+ st.cache_data(show_spinner=True)
+ def calculate_node_sizes(G: nx.Graph) -> list:
+     """
+     Calculates the size of each node in the input graph based on its degree.
+
+     Parameters
+     ----------
+     G : nx.Graph
+         The input graph.
+
+     Returns
+     -------
+     list
+         A list of node sizes.
+     """
+     degrees = dict(G.degree())
+     max_degree = max(deg for node, deg in degrees.items())
+     return [10 + 20 * (degrees[node] / max_degree) for node in G.nodes()]
+
+ @st.cache_data(show_spinner=True)
+ def plot_graph(_G: nx.Graph, layout: str = "fdp"):
+     """
+     Plots a network graph with communities.
+
+     Parameters
+     ----------
+     G : nx.Graph
+         The input graph.
+     layout : str, optional
+         The layout algorithm to use (default is "fdp").
+     """
+     pos = nx.spring_layout(G_backbone, dim=3, seed=779)
+     community_colors = assign_community_colors(_G)
+     node_colors = [community_colors[_G.nodes[n]['community']] for n in _G.nodes]
+
+     edge_trace = go.Scatter(x=[item for sublist in [[pos[edge[0]][0], pos[edge[1]][0], None] for edge in _G.edges()] for item in sublist],
+                             y=[item for sublist in [[pos[edge[0]][1], pos[edge[1]][1], None] for edge in _G.edges()] for item in sublist],
+                             line=dict(width=0.5, color='#888'),
+                             hoverinfo='none',
+                             mode='lines')
+
+     node_trace = go.Scatter(x=[pos[n][0] for n in _G.nodes()],
+                             y=[pos[n][1] for n in _G.nodes()],
+                             mode='markers',
+                             hoverinfo='text',
+                             marker=dict(color=node_colors, size=10, line_width=2))
+
+     node_trace.text = generate_hover_text(_G)
+     node_trace.marker.size = calculate_node_sizes(_G)
+
+     fig = go.Figure(data=[edge_trace, node_trace],
+                     layout=go.Layout(title='Network graph with communities',
+                                      titlefont=dict(size=16),
+                                      showlegend=False,
+                                      hovermode='closest',
+                                      margin=dict(b=20, l=5, r=5, t=40),
+                                      xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+                                      yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+                                      height=800))
+
+     # Extract node positions
+     Xn = [pos[k][0] for k in G_backbone.nodes()]  # x-coordinates of nodes
+     Yn = [pos[k][1] for k in G_backbone.nodes()]  # y-coordinates
+     Zn = [pos[k][2] for k in G_backbone.nodes()]  # z-coordinates
+
+     # Extract edge positions
+     Xe = []
+     Ye = []
+     Ze = []
+     for e in G_backbone.edges():
+         Xe += [pos[e[0]][0], pos[e[1]][0], None]  # x-coordinates of edge ends
+         Ye += [pos[e[0]][1], pos[e[1]][1], None]
+         Ze += [pos[e[0]][2], pos[e[1]][2], None]
+
+     # Define traces for plotly
+     trace1 = go.Scatter3d(x=Xe,
+                           y=Ye,
+                           z=Ze,
+                           mode='lines',
+                           line=dict(color='rgb(125,125,125)', width=1),
+                           hoverinfo='none'
+                           )
+
+     # Map community numbers to names
+     community_names = {i: f"Community {i+1}" for i in range(len(communities))}
+
+     # Create hover text
+     hover_text = [f"{node} ({community_names[G_backbone.nodes[node]['community']]})" for node in G_backbone.nodes()]
+
+     trace2 = go.Scatter3d(x=Xn,
+                           y=Yn,
+                           z=Zn,
+                           mode='markers',
+                           name='actors',
+                           marker=dict(symbol='circle',
+                                       size=7,
+                                       color=node_colors,  # pass hex colors
+                                       line=dict(color='rgb(50,50,50)', width=0.2)
+                                       ),
+                           text=hover_text,  # Use community names as hover text
+                           hoverinfo='text'
+                           )
+
+     axis = dict(showbackground=False,
+                 showline=False,
+                 zeroline=False,
+                 showgrid=False,
+                 showticklabels=False,
+                 title=''
+                 )
+
+     layout = go.Layout(
+         title="3D Network Graph",
+         width=1000,
+         height=1000,
+         showlegend=False,
+         scene=dict(
+             xaxis=dict(axis),
+             yaxis=dict(axis),
+             zaxis=dict(axis),
+         ),
+         margin=dict(
+             t=100
+         ),
+         hovermode='closest',
+     )
+
+     data = [trace1, trace2]
+     fig = go.Figure(data=data, layout=layout)
+     return fig
+
st.title("#ditaduranuncamais Data Explorer")

def check_password():
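
As a sanity check on the backbone extraction added here, a toy run (this assumes the disparity_filter function from this hunk is in scope; the graph and its weights are invented):

```python
import networkx as nx

# Invented weighted graph: one dominant edge plus several weak ones around node 'a'
G = nx.Graph()
G.add_edge('a', 'b', weight=50)          # strong tie
for other in ['c', 'd', 'e', 'f', 'g']:
    G.add_edge('a', other, weight=1)     # weak ties

backbone = disparity_filter(G, weight='weight', alpha=0.05)
print(list(backbone.edges(data=True)))
# Only ('a', 'b') should survive: seen from 'a' (degree 6, total weight 55),
# the test gives (1 - 50/55) ** 5 ≈ 6e-06 < 0.05 for the strong edge,
# while each weak edge gives (1 - 1/55) ** 5 ≈ 0.91, well above alpha.
# Degree-1 neighbours are skipped entirely (the filter requires k_n > 1).
```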
 
@@ -268,7 +503,7 @@ df = load_dataframe(dataset)
image_model = load_img_model()
text_model = load_txt_model()

- menu_options = ["Data exploration", "Semantic search", "Stats"]
+ menu_options = ["Data exploration", "Semantic search", "Hashtags", "Stats"]
st.sidebar.markdown('# Menu')
selected_menu_option = st.sidebar.radio("Select a page", menu_options)

 
@@ -379,7 +614,109 @@ elif selected_menu_option == "Semantic search":
        },
        hide_index=True,
    )
-
+ elif selected_menu_option == "Hashtags":
+     if 'dfx' not in st.session_state:
+         st.session_state.dfx = df.copy()  # Keep a working copy of df in session state
+     # Get a list of all unique hashtags in the DataFrame
+     all_hashtags = list(set([item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]))
+
+     st.sidebar.markdown('# Hashtag co-occurrence analysis options')
+     # Let users select hashtags to remove
+     hashtags_to_remove = st.sidebar.multiselect("Hashtags to remove", all_hashtags)
+
+     col1, col2 = st.sidebar.columns(2)
+     # Add a button to trigger the removal operation
+     if col1.button("Remove hashtags"):
+         # Remove the selected hashtags from every post's hashtag set
+         st.session_state.dfx['Hashtags'] = st.session_state.dfx['Hashtags'].apply(lambda x: [item for item in x if item not in hashtags_to_remove])
+
+     # Add a reset button
+     if col2.button("Reset"):
+         st.session_state.dfx = df.copy()  # Reset dfx to the original DataFrame
+
+     # df2['Hashtags'] = df2['Hashtags'].apply(lambda x: [item for item in x if not item == 'ditaduranuncamais'])
+     # Flatten all hashtags into a single list
+     hashtags = [item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]
+     # Count the number of posts per hashtag
+     hashtag_freq = st.session_state.dfx.explode('Hashtags').groupby('Hashtags').size().reset_index(name='counts')
+     # Sort the hashtags by frequency
+     hashtag_freq = hashtag_freq.sort_values(by='counts', ascending=False)
+
+     # Make the scatter plot
+     hashtags_fig = px.scatter(hashtag_freq, x='Hashtags', y='counts', log_y=True,  # Set log_y to True to make the plot more readable on a log scale
+                               labels={'Hashtags': 'Hashtags', 'counts': 'Frequency'},
+                               title='Frequency of hashtags in #throwbackthursday posts on Instagram',
+                               height=600)  # Set the height to 600 pixels
+     st.markdown("### Hashtag Frequency Distribution")
+     st.markdown('Here we apply hashtag co-occurrence analysis for mnemonic community detection. This detects communities by creating a network of hashtag pairs (which hashtags are used together in which posts) and then applying community detection algorithms to this network.')
+     st.plotly_chart(hashtags_fig)
+
+     weight_option = st.sidebar.radio(
+         'Select weight definition',
+         ('Number of users that use the hashtag pairs', 'Total number of occurrences')
+     )
+
+     hashtag_user_pairs = [(tuple(sorted(combination)), userid) for hashtags, userid in zip(st.session_state.dfx['Hashtags'], st.session_state.dfx['User Name']) for combination in combinations(hashtags, r=2)]
+     # Create a DataFrame with columns 'hashtag_pair' and 'userid'
+     hashtag_user_df = pd.DataFrame(hashtag_user_pairs, columns=['hashtag_pair', 'User Name'])
+     if weight_option == 'Number of users that use the hashtag pairs':
+         # Group by 'hashtag_pair' and count the number of unique 'userid's
+         hashtag_user_df = hashtag_user_df.groupby('hashtag_pair').agg({'User Name': 'nunique'}).reset_index()
+     elif weight_option == 'Total number of occurrences':
+         # Group by 'hashtag_pair' and count the total number of occurrences
+         hashtag_user_df = hashtag_user_df.groupby('hashtag_pair').size().reset_index(name='User Name')
+     # Make edge_list from hashtag_user_df with columns 'hashtag1', 'hashtag2', and 'weight'
+     edge_list = hashtag_user_df.rename(columns={'hashtag_pair': 'hashtag1', 'User Name': 'weight'})
+     edge_list[['hashtag1', 'hashtag2']] = pd.DataFrame(edge_list['hashtag1'].tolist(), index=edge_list.index)
+     edge_list = edge_list[['hashtag1', 'hashtag2', 'weight']]
+
+     st.markdown("### Edge List of Hashtag Pairs")
+     # Create the graph using the pair weights as edge attributes
+     G = nx.from_pandas_edgelist(edge_list, 'hashtag1', 'hashtag2', 'weight')
+     G_backbone = disparity_filter(G, weight='weight', alpha=0.05)
+     st.markdown(f'Number of nodes {len(G_backbone.nodes)}')
+     st.markdown(f'Number of edges {len(G_backbone.edges)}')
+     st.dataframe(edge_list.sort_values(by='weight', ascending=False).head(10).style.set_caption("Edge list of hashtag pairs with the highest weight"))
+
+     # Create louvain communities
+     communities = nx.community.louvain_communities(G_backbone, weight='weight', seed=1234)
+     communities = list(communities)
+
+     # Sort communities by size
+     communities.sort(key=len, reverse=True)
+
+     for i, community in enumerate(communities):
+         for node in community:
+             G_backbone.nodes[node]['community'] = i
+
+     # Sort community hashtags based on their weighted degree in the network
+     sorted_community_hashtags = [
+         [
+             hashtag
+             for hashtag, degree in sorted(
+                 ((h, G.degree(h, weight='weight')) for h in community),
+                 key=lambda x: x[1],
+                 reverse=True
+             )
+         ]
+         for community in communities
+     ]
+
+     # Convert the sorted_community_hashtags list into a DataFrame and transpose it
+     sorted_community_hashtags = pd.DataFrame(sorted_community_hashtags).T
+
+     # Rename the columns of sorted_community_hashtags DataFrame
+     sorted_community_hashtags.columns = [f'Community {i+1}' for i in range(len(sorted_community_hashtags.columns))]
+
+     st.markdown("### Hashtag Communities")
+     st.markdown(f'There are {len(communities)} communities in the graph.')
+     st.data_editor(sorted_community_hashtags)
+
+     st.markdown("### Hashtag Network Graph")
+     st.plotly_chart(plot_graph(G_backbone, layout="fdp"))  # fdp is relatively slow; use 'sfdp' or 'neato' for faster but denser layouts
+
elif selected_menu_option == "Stats":
    st.markdown("### Time Series Analysis")
    # Dropdown to select variables
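
The pair expansion plus the two weight definitions boil down to a combinations-and-groupby step; a minimal sketch with invented posts and user names, using the "number of users" weighting:

```python
from itertools import combinations
import pandas as pd

posts = pd.DataFrame({
    'Hashtags': [{'memoria', 'verdade'}, {'memoria', 'verdade', 'justica'}, {'memoria', 'verdade'}],
    'User Name': ['ana', 'ana', 'bruno'],
})

# One row per (sorted hashtag pair, user) occurrence, as in the app code
pairs = [(tuple(sorted(pair)), user)
         for tags, user in zip(posts['Hashtags'], posts['User Name'])
         for pair in combinations(tags, r=2)]

# Weight = number of distinct users posting the pair (the first radio option)
edges = (pd.DataFrame(pairs, columns=['hashtag_pair', 'User Name'])
         .groupby('hashtag_pair').agg({'User Name': 'nunique'})
         .reset_index()
         .rename(columns={'User Name': 'weight'}))
print(edges)
# roughly:
#          hashtag_pair  weight
# 0  (justica, memoria)       1
# 1  (justica, verdade)       1
# 2  (memoria, verdade)       2
```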
 
@@ -439,4 +776,4 @@ elif selected_menu_option == "Stats":
    elif corr > -0.7:
        st.write(f"The correlation coefficient is {corr}, indicating a moderate negative relationship between {scatter_variable_1} and {scatter_variable_2}.")
    else:
-         st.write(f"The correlation coefficient is {corr}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.")
+         st.write(f"The correlation coefficient is {corr}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.")
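
Downstream of the backbone, the community step is plain networkx Louvain; a small sketch of the same pattern on an invented two-cluster graph (louvain_communities requires networkx >= 2.8):

```python
import networkx as nx

# Invented graph: two tight triangles joined by a weak bridge
G = nx.Graph()
G.add_weighted_edges_from([('a', 'b', 3), ('b', 'c', 3), ('a', 'c', 3),
                           ('x', 'y', 3), ('y', 'z', 3), ('x', 'z', 3),
                           ('c', 'x', 1)])

# Same pattern as the Hashtags page: detect, sort by size, tag nodes
communities = sorted(nx.community.louvain_communities(G, weight='weight', seed=1234),
                     key=len, reverse=True)
for i, community in enumerate(communities):
    for node in community:
        G.nodes[node]['community'] = i

print(communities)                # e.g. [{'a', 'b', 'c'}, {'x', 'y', 'z'}]
print(G.nodes['a']['community'])  # index of the community containing 'a'
```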