# la-metro / app.py
# David Wisdom
# switch Jasmine to Robertson
# 37efc93
import os
# stop tensorflow from printing novels to stdout
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import pickle
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import tensorflow as tf
import tensorflow_hub as hub
from sklearn.cluster import DBSCAN
def read_stops(p: str):
    """
    Load the table of metro stops from disk.

    :param p: Location of the .csv file that holds the metro stops
    :returns: The parsed file contents as a Pandas DataFrame
    """
    stops = pd.read_csv(p)
    return stops
def read_encodings(p: str) -> tf.Tensor:
    """
    Load the pickled Universal Sentence Encoder v4 encodings from disk.

    WARNING: nothing here guards against the security holes in `pickle`,
    so only point this function at files you trust.

    :param p: Path to the pickled encodings
    :returns: A Tensor of the encodings with shape (number of sentences, 512)
    """
    # Binary mode is required: pickle streams are bytes, not text
    with open(p, 'rb') as pickled:
        return pickle.load(pickled)
def cluster_encodings(encodings: tf.Tensor, *,
                      eps: float = 0.7,
                      min_samples: int = 100) -> np.ndarray:
    """
    Cluster the sentence encodings using DBSCAN.

    :param encodings: A Tensor of sentence encodings with shape
        (number of sentences, 512)
    :param eps: The `eps` value handed to DBSCAN. Defaults to the value
        chosen during the notebook EDA.
    :param min_samples: The `min_samples` value handed to DBSCAN. Defaults
        to the value chosen during the notebook EDA.
    :returns: a NumPy array of the cluster labels, one per encoding
        (points that belong to no cluster are labeled -1)
    """
    # The defaults reproduce the hyperparams picked in the notebook; they are
    # keyword-only parameters so other values can be tried without editing
    # this function.
    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer.fit(encodings)
    return clusterer.labels_
def cluster_lat_lon(df: pd.DataFrame, *,
                    eps: float = 0.025,
                    min_samples: int = 100) -> np.ndarray:
    """
    Cluster the metro stops by their latitude and longitude using DBSCAN.

    :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns
    :param eps: The `eps` value handed to DBSCAN. Defaults to the value
        chosen during the notebook EDA.
    :param min_samples: The `min_samples` value handed to DBSCAN. Defaults
        to the value chosen during the notebook EDA.
    :returns: a NumPy array of the cluster labels, one per row of `df`
        (points that belong to no cluster are labeled -1)
    """
    # NOTE(review): no metric is passed, so DBSCAN's default applies to the raw
    # coordinate values -- eps is presumably in degrees, not a ground distance.
    coordinates = df[['latitude', 'longitude']]
    # The defaults reproduce the hyperparams picked in the notebook; they are
    # keyword-only parameters so other values can be tried without editing
    # this function.
    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer.fit(coordinates)
    return clusterer.labels_
def plot_example(df: pd.DataFrame, labels: np.ndarray):
    """
    Build the map that shows the geographic clustering of the stops.

    :param df: A Pandas DataFrame of stops that has 'latitude', 'longitude'
        and 'display_name' columns
    :param labels: a NumPy array of the cluster labels, one per row of `df`
    :returns: the Plotly figure; displaying it is left to the caller
    """
    # The Mapbox token comes from Streamlit's secrets store
    px.set_mapbox_access_token(st.secrets['mapbox_token'])

    # Color by the string form of the labels; the colors themselves come from
    # the discrete Dark24 palette passed below.
    cluster_names = labels.astype('str')

    return px.scatter_mapbox(
        df,
        lat='latitude',
        lon='longitude',
        color=cluster_names,
        hover_name='display_name',
        zoom=8,
        color_discrete_sequence=px.colors.qualitative.Dark24,
    )
def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
    """
    Build a map of the metro stops, colored by the name-based cluster labels
    and centered on Venice Blvd.

    :param df: A Pandas DataFrame of stops that has 'latitude', 'longitude'
        and 'display_name' columns
    :param labels: a NumPy array of the cluster labels, one per row of `df`
    :returns: the Plotly figure; displaying it is left to the caller
    """
    # The Mapbox token comes from Streamlit's secrets store
    px.set_mapbox_access_token(st.secrets['mapbox_token'])

    # Point the initial view is centered on
    map_center = dict(lat=34.008350, lon=-118.425362)

    # Color by the string form of the labels; the colors themselves come from
    # the discrete Dark24 palette passed below.
    cluster_names = labels.astype('str')

    return px.scatter_mapbox(
        df,
        lon='longitude',
        lat='latitude',
        hover_name='display_name',
        color=cluster_names,
        center=map_center,
        zoom=12,
        color_discrete_sequence=px.colors.qualitative.Dark24,
    )
def main(data_path: str, enc_path: str):
    """
    Build both cluster maps and render them, with their write-ups, via Streamlit.

    :param data_path: Path to the .csv file of metro stops
    :param enc_path: Path to the pickled encodings of the stop names
    """
    stops = read_stops(data_path)

    # Map 1: stops clustered on latitude/longitude
    position_fig = plot_example(stops, cluster_lat_lon(stops))

    # Map 2: stops clustered on the encodings of their names
    name_fig = plot_venice_blvd(
        stops,
        cluster_encodings(read_encodings(enc_path)),
    )

    position_blurb = """First, I clustered the
stops by their geographic location.
The DBSCAN algorithm finds three clusters.
Points labeled `-1` aren't part of any cluster.
Clicking on `-1` in the legend will turn off those points."""

    name_blurb = """I encoded the names of all the stops using the Universal Sentence Encoder v4.
I then clustered those encodings so that I could group the stops based on their names
instead of their geographic position.
As I expected, stops on the same road end up close enough to each other that DBSCAN can cluster them together.
Sometimes, however, a stop has a name that means something to the encoder.
When that happens, the encoding ends up too far away from the rest of the stops on that road.
For example, the stops on Venice Blvd get clustered together,
but the stop "Venice / Lincoln" ends up somewhere else.
I assume it ends up somewhere else because the encoder recognizes "Lincoln"
and that meaning overpowers the "Venice" meaning enough that the encoding
is too far away from the rest of the "Venice" stops.
A few other examples on Venice Blvd are "Saint Andrews," "Harvard," and "Beethoven."
There are also a few that I don't ascribe much meaning to, such as "Girard" and "Robertson."
There's a lot more to dig into here but I'll leave it there for now.
My mind first jumps to adversarial prompts that use famous names to move the encoding
around in the encoding space.
"""

    # Nothing is written to the page until both figures have been built
    st.write('# Cluster the stops by their position')
    st.write(position_blurb)
    st.plotly_chart(position_fig, use_container_width=True)

    st.write('# Cluster the stops by their name')
    st.write(name_blurb)
    st.plotly_chart(name_fig, use_container_width=True)
if __name__ == '__main__':
    # Script entry point: read the two optional path flags and hand them to main().
    import argparse

    default_data_path = 'data/stops.csv'
    default_enc_path = 'data/encodings.pkl'

    parser = argparse.ArgumentParser()
    # With nargs='?' a flag given WITHOUT a value takes `const`, which is None
    # unless set. Setting const to the default path means a bare `--data_path`
    # or `--enc_path` falls back to the default instead of passing None to main().
    parser.add_argument('--data_path',
                        nargs='?',
                        const=default_data_path,
                        default=default_data_path,
                        help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'")
    parser.add_argument('--enc_path',
                        nargs='?',
                        const=default_enc_path,
                        default=default_enc_path,
                        help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'")
    args = parser.parse_args()
    # The argument names match main()'s parameter names, so the parsed
    # namespace can be unpacked directly into keyword arguments.
    main(**vars(args))