import os
# stop tensorflow from printing novels to stdout
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import pickle

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import tensorflow as tf
import tensorflow_hub as hub

from sklearn.cluster import DBSCAN


def read_stops(p: str) -> pd.DataFrame:
  """
  Read in the .csv file of metro stops

  :param p: The path to the .csv file of metro stops

  :returns: A Pandas DataFrame of the stops
  """
  return pd.read_csv(p)


def read_encodings(p: str) -> tf.Tensor:
  """
  Unpickle the Universal Sentence Encoder v4 encodings
  and return them
  
  This function doesn't make any attempt to patch the security holes in `pickle`.
  
  :param p: Path to the encodings
 
  :returns: A Tensor of the encodings with shape (number of sentences, 512)
  """
  with open(p, 'rb') as f:
    encodings = pickle.load(f)
  return encodings
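

# For reference, a minimal sketch of how the pickled encodings could have been
# produced. This helper isn't called by the app; it assumes the stop names live
# in the 'display_name' column (the column used for hover text below) and that
# the Universal Sentence Encoder v4 is fetched from TF Hub.
def encode_stop_names(df: pd.DataFrame) -> tf.Tensor:
  """
  Encode the stop names with the Universal Sentence Encoder v4.

  :param df: A Pandas DataFrame of stops with a 'display_name' column

  :returns: A Tensor of the encodings with shape (number of stops, 512)
  """
  # Downloads the model on first use, then reads it from the TF Hub cache
  encoder = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
  # The encoder maps a batch of strings to 512-dimensional vectors
  return encoder(df['display_name'].tolist())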


def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
  """
  Cluster the sentence encodings using DBSCAN.

  :param encodings: A Tensor of sentence encodings with shape
                    (number of sentences, 512)

  :returns: a NumPy array of the cluster labels
  """
  # I know the hyperparams I want from the EDA I did in the notebook
  clusterer = DBSCAN(eps=0.7, min_samples=100).fit(encodings)
  return clusterer.labels_


def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
  """
  Cluster the metro stops by their latitude and longitude using DBSCAN.

  :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns

  :returns: a NumPy array of the cluster labels
  """
  # I know the hyperparams I want from the EDA I did in the notebook
  clusterer = DBSCAN(eps=0.025, min_samples=100).fit(df[['latitude', 'longitude']])
  return clusterer.labels_
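

# A rough sketch of the kind of EDA the comments above refer to: a common
# heuristic for picking DBSCAN's eps is to plot each point's (sorted) distance
# to its k-th nearest neighbor, with k matched to min_samples, and read eps off
# the elbow of the curve. This helper and its use of Plotly are my own
# illustration, not something pulled from the original notebook.
def k_distance_plot(X, k: int = 100):
  """
  Plot the sorted distance from each point to its k-th nearest neighbor.

  :param X: An array-like of the points to be clustered
  :param k: The neighbor rank to use (match it to DBSCAN's min_samples)

  :returns: A Plotly figure of the k-distance curve
  """
  # Local import so this optional helper stays self-contained
  from sklearn.neighbors import NearestNeighbors

  nn = NearestNeighbors(n_neighbors=k).fit(X)
  distances, _ = nn.kneighbors(X)
  # The last column holds the distance to the k-th neighbor returned
  # (each query point counts itself as its own nearest neighbor)
  return px.line(y=np.sort(distances[:, -1]))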


def plot_example(df: pd.DataFrame, labels: np.ndarray):
  """
  Plot the geographic clustering

  :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns
  :param labels: a NumPy array of the cluster labels

  :returns: A Plotly figure of the stops on a map, colored by cluster label
  """
  px.set_mapbox_access_token(st.secrets['mapbox_token'])
  labels = labels.astype('str')

  fig = px.scatter_mapbox(df, lon='longitude', lat='latitude',
                          hover_name='display_name',
                          color=labels,
                          zoom=9,
                          color_discrete_sequence=px.colors.qualitative.Dark24)
  return fig
  

def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
  """
  Plot the metro stops and color them based on their names

  :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns
  :param labels: a NumPy array of the cluster labels

  :returns: A Plotly figure of the stops centered on Venice Blvd, colored by cluster label
  """
  px.set_mapbox_access_token(st.secrets['mapbox_token'])
  venice_blvd = {'lat': 34.008350,
                 'lon': -118.425362}
  labels = labels.astype('str')

  fig = px.scatter_mapbox(df, lat='latitude', lon='longitude',
                          color=labels,
                          hover_name='display_name', 
                          center=venice_blvd,
                          zoom=12,
                          color_discrete_sequence=px.colors.qualitative.Dark24)

  return fig
  
  
def main(data_path: str, enc_path: str):
  df = read_stops(data_path)

  # Cluster based on lat/lon
  example_labels = cluster_lat_lon(df)
  example_fig = plot_example(df, example_labels)

  # Cluster based on the name of the stop
  encodings = read_encodings(enc_path)
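  # The Venice Blvd plot colors the stops in `df` by the encoding clusters, so
  # the rows of `df` and `encodings` need to line up one-to-one
  assert len(encodings) == len(df), 'stops and encodings are out of sync'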
  encoding_labels = cluster_encodings(encodings)
  venice_fig = plot_venice_blvd(df, encoding_labels)

  # Display the plots with Streamlit
  st.write('# Example of what DBSCAN does')
  st.write("""As an example of a typical DBSCAN result, I've clustered the
stops by their geographic location.
The algorithm finds three clusters.
Points labeled `-1` aren't part of any cluster.
Clicking on `-1` in the legend will turn off those points.""")
  st.plotly_chart(example_fig, use_container_width=True)

  st.write('# Venice Blvd')
  st.write("""I encoded the names of all the stops using the Universal Sentence Encoder v4.
I then clustered those encodings so that I could group the stops based on their names
instead of their geographic position.
As I expected, stops on the same road end up close enough to each other that DBSCAN can cluster them together.
Sometimes, however, a stop has a name that means something to the encoder,
and its encoding lands too far from the rest of the stops on that road.
For example, the stops on Venice Blvd get clustered together,
but `Venice / Lincoln` ends up somewhere else.
I assume that's because the encoder recognizes "Lincoln",
and that meaning overpowers the "Venice" meaning enough to pull the encoding
away from the rest of the "Venice" stops.
A few other examples on Venice Blvd are "Saint Andrews," "Harvard," and "Beethoven."
There are a few that I don't ascribe much meaning to, such as "Girard" and "Jasmine."
My mind first jumps to adversarial prompts that use famous names to move the encoding
around in the encoding space.
There's a lot more to dig into here, but I'll leave it at that for now.
""")
  st.plotly_chart(venice_fig, use_container_width=True)
  

if __name__ == '__main__':
  import argparse

  p = argparse.ArgumentParser()
  p.add_argument('--data_path',
                 nargs='?',
                 default='data/stops.csv',
                 help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'")
  p.add_argument('--enc_path',
                 nargs='?',
                 default='data/encodings.pkl',
                 help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'")
  args = p.parse_args()

  main(**vars(args))
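
# The app is meant to be launched through Streamlit's runner; anything after
# the `--` separator is passed straight through to argparse. The filename
# app.py is an assumption here (this file's real name isn't shown):
#
#   streamlit run app.py -- --data_path data/stops.csv --enc_path data/encodings.pkl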