David Wisdom committed on
Commit
d256b25
1 Parent(s): c05b194

plot the example stops on a map as well

Browse files
Files changed (1) hide show
  1. app.py +50 -15
app.py CHANGED
@@ -15,7 +15,9 @@ from sklearn.cluster import DBSCAN
15
 
16
  def read_stops(p: str):
17
  """
18
- DOCSTRING
 
 
19
  """
20
  return pd.read_csv(p)
21
 
@@ -38,7 +40,12 @@ def read_encodings(p: str) -> tf.Tensor:
38
 
39
  def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
40
  """
41
- DOCSTRING
 
 
 
 
 
42
  """
43
  # I know the hyperparams I want from the EDA I did in the notebook
44
  clusterer = DBSCAN(eps=0.7, min_samples=100).fit(encodings)
@@ -47,7 +54,11 @@ def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
47
 
48
  def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
49
  """
50
- DOCSTRING
 
 
 
 
51
  """
52
  # I know the hyperparams I want from the EDA I did in the notebook
53
  clusterer = DBSCAN(eps=0.025, min_samples=100).fit(df[['latitude', 'longitude']])
@@ -56,26 +67,28 @@ def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
56
 
57
  def plot_example(df: pd.DataFrame, labels: np.ndarray):
58
  """
59
- DOCSTRING
 
 
 
60
  """
61
- plot_size = 800
62
  labels = labels.astype('str')
63
 
64
- fig = px.scatter(df, x='longitude', y='latitude',
65
- hover_name='display_name',
66
- color=labels,
67
- opacity=0.5,
68
- color_discrete_sequence=px.colors.qualitative.Safe,
69
- template='presentation',
70
- width=plot_size,
71
- height=plot_size)
72
- # fig.show()
73
  return fig
74
 
75
 
76
  def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
77
  """
78
- DOCSTRING
 
 
 
79
  """
80
  px.set_mapbox_access_token(st.secrets['mapbox_token'])
81
  venice_blvd = {'lat': 34.008350,
@@ -107,9 +120,31 @@ def main(data_path: str, enc_path: str):
107
 
108
  # Display the plots with Streamlit
109
  st.write('# Example of what DBSCAN does')
 
 
 
 
 
110
  st.plotly_chart(example_fig, use_container_width=True)
111
 
112
  st.write('# Venice Blvd')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  st.plotly_chart(venice_fig, use_container_width=True)
114
 
115
 
 
15
 
16
  def read_stops(p: str):
17
  """
18
+ Read in the .csv file of metro stops
19
+
20
+ :param p: The path to the .csv file of metro stops
21
  """
22
  return pd.read_csv(p)
23
 
 
40
 
41
  def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
42
  """
43
+ Cluster the sentence encodings using DBSCAN.
44
+
45
+ :param encodings: A Tensor of sentence encodings with shape
46
+ (number of sentences, 512)
47
+
48
+ :returns: a NumPy array of the cluster labels
49
  """
50
  # I know the hyperparams I want from the EDA I did in the notebook
51
  clusterer = DBSCAN(eps=0.7, min_samples=100).fit(encodings)
 
54
 
55
  def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
56
  """
57
+ Cluster the metro stops by their latitude and longitude using DBSCAN.
58
+
59
+ :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns
60
+
61
+ :returns: a NumPy array of the cluster labels
62
  """
63
  # I know the hyperparams I want from the EDA I did in the notebook
64
  clusterer = DBSCAN(eps=0.025, min_samples=100).fit(df[['latitude', 'longitude']])
 
67
 
68
  def plot_example(df: pd.DataFrame, labels: np.ndarray):
69
  """
70
+ Plot the geographic clustering
71
+
72
+ :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns
73
+ :param labels: a NumPy array of the cluster labels
74
  """
75
+ px.set_mapbox_access_token(st.secrets['mapbox_token'])
76
  labels = labels.astype('str')
77
 
78
+ fig = px.scatter_mapbox(df, x='longitude', y='latitude',
79
+ hover_name='display_name',
80
+ color=labels,
81
+ zoom=10,
82
+ color_discrete_sequence=px.colors.qualitative.Safe,
 
 
 
 
83
  return fig
84
 
85
 
86
  def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
87
  """
88
+ Plot the metro stops and color them based on their names
89
+
90
+ :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns
91
+ :param labels: a NumPy array of the cluster labels
92
  """
93
  px.set_mapbox_access_token(st.secrets['mapbox_token'])
94
  venice_blvd = {'lat': 34.008350,
 
120
 
121
  # Display the plots with Streamlit
122
  st.write('# Example of what DBSCAN does')
123
+ st.write("""As an example of a typical DBSCAN result, I've clustered the
124
+ stops by their geographic location.
125
+ The algorithm finds three clusters.
126
+ Points labeled `-1` aren't part of any cluster.
127
+ Clicking on `-1` in the legend will turn off those points."""
128
  st.plotly_chart(example_fig, use_container_width=True)
129
 
130
  st.write('# Venice Blvd')
131
+ st.write("""I encoded the names of all the stops using the Universal Sentence Encoder v4.
132
+ I then clustered those encodings so that I could group the stops based on their names
133
+ instead of their geographic position.
134
+ As I expected, stops on the same road end up close enough to each other that DBSCAN can cluster them together.
135
+ Sometimes, however, a stop has a name that means something to the encoder.
136
+ When that happens, the encoding ends up too far away from the rest of the stops on that road.
137
+ For example, the stops on Venice Blvd get clustered together,
138
+ but the stop `Venice / Lincoln` ends up somewhere else.
139
+ I assume it ends up somewhere else because the encoder recognizes "Lincoln"
140
+ and that meaning overpowers the "Venice" meaning enough that the encoding
141
+ is too far away from the rest of the "Venice" stops.
142
+ A few other examples on Venice Blvd are "Saint Andrews," "Harvard," and "Beethoven."
143
+ There are a few that I don't ascribe much meaning to, such as "Girard" and "Jasmine."
144
+ My mind first jumps to adversarial prompts that use famous names to move the encoding
145
+ around in the encoding space.
146
+ There's a lot more to dig into here but I'll leave it there for now.
147
+ """
148
  st.plotly_chart(venice_fig, use_container_width=True)
149
 
150