"""Interactive PCA demo: a 3-D point cloud that is nearly flat in one direction.

The cloud is drawn together with a patch of the plane spanned by its first two
principal components; a radio button switches between two camera angles.
"""

from sklearn.decomposition import PCA
import matplotlib

# Non-interactive backend: figures are only rendered to images.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import numpy as np
from scipy import stats
import gradio as gr

e = np.exp(1)
np.random.seed(4)


def pdf(x):
    """Evaluate an equal-weight mixture of two zero-mean normal densities.

    The components have standard deviations ``0.25 / e`` and ``4 / e``.
    """
    return 0.5 * (stats.norm(scale=0.25 / e).pdf(x) + stats.norm(scale=4 / e).pdf(x))


# The sampling order is kept fixed so the seeded cloud stays reproducible.
y = np.random.normal(scale=0.5, size=(30000))
x = np.random.normal(scale=0.5, size=(30000))
z = np.random.normal(scale=0.1, size=len(x))

# Per-point colour value for the scatter plot.
density = pdf(x) * pdf(y)
pdf_z = pdf(5 * z)
density *= pdf_z

# c is a linear combination of a and b plus the small noise z, so the cloud
# is nearly flat along one direction.
a = x + y
b = 2 * y
c = a - b + z

norm = np.sqrt(a.var() + b.var())
a /= norm
b /= norm


def _pca_plane(a, b, c):
    """Return 2x2 (X, Y, Z) grids for a patch of the first-two-PC plane.

    The patch corners are ``[[pc0, pc1], [-pc1, -pc0]]`` with each principal
    component vector scaled by 3.
    """
    points = np.c_[a, b, c]
    # The same directions (up to sign) are the rows of Vt from an SVD of the
    # centered data: scipy.linalg.svd(points - points.mean(0), full_matrices=False)
    pca = PCA(n_components=3)
    pca.fit(points)
    # Column k of V is the k-th principal direction.
    V = pca.components_.T
    # Row i of 3 * V holds the i-th coordinate of every scaled component, so
    # each row yields one coordinate grid of the patch corners.
    return tuple(np.r_[axis[:2], -axis[1::-1]].reshape(2, 2) for axis in 3 * V)


# The cloud never changes, so fit PCA once instead of on every request.
_PCA_PLANE = _pca_plane(a, b, c)


def plot_figs(fig_num, elev, azim):
    """Draw the point cloud and its principal plane from one viewpoint.

    Parameters
    ----------
    fig_num : int
        Used only to name the PNG written to the working directory.
    elev, azim : float
        Elevation and azimuth (degrees) of the 3-D camera.

    Returns
    -------
    matplotlib.figure.Figure
        The rendered figure, also saved as ``f"{fig_num}.png"``.
    """
    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive until closed (one leaked figure per request in a
    # long-running app), and its global "current figure" could be shared if
    # handlers run on several threads.
    fig = Figure()
    ax = fig.add_subplot(111, projection="3d", elev=elev, azim=azim)
    ax.set_position([0, 0, 0.95, 1])

    # Only every 10th point is drawn.
    ax.scatter(a[::10], b[::10], c[::10], c=density[::10], marker="+", alpha=0.4)
    ax.plot_surface(*_PCA_PLANE)

    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_ticklabels([])

    fig.savefig(f"{fig_num}.png")
    return fig


def make_plot(plot_type):
    """Return the figure for the selected radio option.

    ``"Very flat direction"`` uses elev=-40, azim=-80 (saved as ``1.png``);
    any other value uses elev=30, azim=20 (saved as ``2.png``).
    """
    if plot_type == "Very flat direction":
        elev, azim, fig_num = -40, -80, 1
    else:
        elev, azim, fig_num = 30, 20, 2
    # The figure must be returned, otherwise the gr.Plot output receives None.
    return plot_figs(fig_num, elev, azim)


title = "Principal components analysis (PCA)"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "These figures aid in illustrating how a point cloud can be "
        "very flat in one direction–which is where PCA comes in to choose "
        "a direction that is not flat."
    )
    button = gr.Radio(
        label="Plot type",
        choices=["Very flat direction", "Not flat direction"],
        value="Very flat direction",
    )
    plot = gr.Plot(label="Plot")
    button.change(make_plot, inputs=button, outputs=[plot])
    demo.load(make_plot, inputs=[button], outputs=[plot])

demo.launch()