"""
Streamlit app that clusters LA Metro stops with DBSCAN.

Two clusterings are shown:

* one on raw latitude/longitude, as an illustration of what DBSCAN does;
* one on Universal Sentence Encoder v4 encodings of the stop names, drawn
  on a Mapbox map centred on Venice Blvd.
"""
import os

# stop tensorflow from printing novels to stdout; this has to be set before
# tensorflow is imported, which is why it sits between the imports
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import pickle

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import tensorflow as tf
import tensorflow_hub as hub  # noqa: F401 -- not referenced in this file
from sklearn.cluster import DBSCAN


def read_stops(p: str) -> pd.DataFrame:
    """
    Read the LA Metro stops dataset from a CSV file.

    :param p: Path to the CSV file. The rest of this script expects it to
        contain ``latitude``, ``longitude`` and ``display_name`` columns.
    :returns: The dataset as a DataFrame.
    """
    return pd.read_csv(p)


def read_encodings(p: str) -> tf.Tensor:
    """
    Unpickle the Universal Sentence Encoder v4 encodings and return them.

    This function doesn't make any attempt to patch the security holes in
    `pickle`: only load files you produced yourself, never untrusted input.

    :param p: Path to the encodings
    :returns: A Tensor of the encodings with shape (number of sentences, 512)
    """
    with open(p, 'rb') as f:
        # SECURITY: pickle.load can execute arbitrary code from the file.
        encodings = pickle.load(f)
    return encodings


def cluster_encodings(encodings: tf.Tensor, *, eps: float = 0.7,
                      min_samples: int = 100) -> np.ndarray:
    """
    Cluster sentence encodings with DBSCAN.

    The default hyperparameters were chosen during EDA in the notebook.

    :param encodings: Encodings with shape (number of sentences, 512), as
        returned by `read_encodings`.
    :param eps: DBSCAN neighbourhood radius in encoding space.
    :param min_samples: Minimum neighbourhood size for a core point.
    :returns: One integer cluster label per encoding; DBSCAN labels noise
        points as -1.
    """
    clusterer = DBSCAN(eps=eps, min_samples=min_samples).fit(encodings)
    return clusterer.labels_


def cluster_lat_lon(df: pd.DataFrame, *, eps: float = 0.025,
                    min_samples: int = 100) -> np.ndarray:
    """
    Cluster stops by position with DBSCAN on raw latitude/longitude.

    The default hyperparameters were chosen during EDA in the notebook.

    :param df: Stops with ``latitude`` and ``longitude`` columns.
    :param eps: DBSCAN neighbourhood radius, in the same units as the
        coordinate columns (presumably degrees -- confirm against the data).
    :param min_samples: Minimum neighbourhood size for a core point.
    :returns: One integer cluster label per row of ``df``; DBSCAN labels
        noise points as -1.
    """
    clusterer = DBSCAN(eps=eps, min_samples=min_samples).fit(
        df[['latitude', 'longitude']])
    return clusterer.labels_


def plot_example(df: pd.DataFrame, labels: np.ndarray, *,
                 plot_size: int = 800) -> go.Figure:
    """
    Scatter-plot the stops by longitude/latitude, coloured by cluster label.

    :param df: Stops with ``latitude``, ``longitude`` and ``display_name``
        columns.
    :param labels: One cluster label per row of ``df``.
    :param plot_size: Width and height of the figure in pixels.
    :returns: The Plotly figure.
    """
    # String labels make Plotly treat clusters as discrete categories
    # instead of drawing a continuous colour scale.
    labels = labels.astype('str')
    fig = px.scatter(df, x='longitude', y='latitude',
                     hover_name='display_name', color=labels, opacity=0.5,
                     color_discrete_sequence=px.colors.qualitative.Safe,
                     template='presentation',
                     width=plot_size, height=plot_size)
    return fig


def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray) -> go.Figure:
    """
    Plot the stops on a Mapbox map centred on Venice Blvd, coloured by label.

    Reads the Mapbox token from ``st.secrets['mapbox_token']``.

    :param df: Stops with ``latitude``, ``longitude`` and ``display_name``
        columns.
    :param labels: One cluster label per row of ``df``.
    :returns: The Plotly figure.
    """
    px.set_mapbox_access_token(st.secrets['mapbox_token'])
    venice_blvd = {'lat': 34.008350, 'lon': -118.425362}
    # String labels -> discrete colours (see plot_example).
    labels = labels.astype('str')
    fig = px.scatter_mapbox(df, lat='latitude', lon='longitude', color=labels,
                            hover_name='display_name', center=venice_blvd,
                            zoom=12,
                            color_discrete_sequence=px.colors.qualitative.Dark24)
    return fig


def main(data_path: str, enc_path: str) -> None:
    """
    Build both clusterings and render them with Streamlit.

    :param data_path: Path to the stops CSV.
    :param enc_path: Path to the pickled stop-name encodings; the encodings
        must be in the same order as the rows of the CSV.
    """
    df = read_stops(data_path)

    # Cluster based on lat/lon
    example_labels = cluster_lat_lon(df)
    example_fig = plot_example(df, example_labels)

    # Cluster based on the name of the stop
    encodings = read_encodings(enc_path)
    encoding_labels = cluster_encodings(encodings)
    venice_fig = plot_venice_blvd(df, encoding_labels)

    # Display the plots with Streamlit
    st.write('# Example of what DBSCAN does')
    st.plotly_chart(example_fig, use_container_width=True)

    st.write('# Venice Blvd')
    st.plotly_chart(venice_fig, use_container_width=True)


if __name__ == '__main__':
    import argparse

    # Under `streamlit run`, script arguments go after a `--` separator,
    # e.g. `streamlit run app.py -- --data_path data/stops.csv`.
    p = argparse.ArgumentParser()
    p.add_argument('--data_path', default='data/stops.csv',
                   help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'")
    p.add_argument('--enc_path', default='data/encodings.pkl',
                   help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'")
    args = p.parse_args()
    main(**vars(args))