import copy

import gradio as gr
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the submodule is loaded
import scipy
from scipy.optimize import curve_fit
from scipy.stats import gaussian_kde
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler

# The tab-separated input is indexed by its 'Description' column; after
# dropping 'id' and 'Name' and transposing, every column of `gene_table` is
# one 'Description' value (used throughout this file as the gene name).
# NOTE(review): rows are presumably samples -- confirm against the data file.
df = pd.read_csv(
    './gene_tpm_brain_cerebellar_hemisphere_log2minus1NEW.txt', sep='\t')
gene_table = df.set_index('Description').drop(
    columns=['id', 'Name']).T.reset_index(drop=True)


# =============================================================================
# One-dimensional histogram + Gaussian fit
# =============================================================================

def plot_hist_gauss(col, ax=None, orientation='vertical', label=''):
    """
    Plot a density histogram of `col` and overlay a fitted Gaussian.

    Parameters
    ----------
    col : pandas Series
        The values to histogram.
    ax : matplotlib axes object, default None
        Axes to draw on. If None, pandas chooses the axes and `plt.show()`
        is called before returning.
    orientation : str, default 'vertical'
        'vertical' or 'horizontal'; forwarded to the histogram and used to
        decide which way the fitted curve and mean line are drawn.
    label : str, default ''
        Label for the value axis.

    Returns
    -------
    popt : numpy array
        Optimal `(A, mu, sigma)` of `A * exp(-(x - mu)**2 / (2 * sigma**2))`
        as returned by `scipy.optimize.curve_fit`.
    """
    show = ax is None
    ax = col.plot.hist(orientation=orientation, density=True, alpha=0.2,
                       ax=ax, subplots=False)

    hist, bin_edges = np.histogram(col, density=True)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    def gauss(x, A, mu, sigma):
        return A * np.exp(-(x - mu)**2 / (2. * sigma**2))

    # Seed the optimiser from the data itself. A fixed guess such as
    # (1, 5, 1) puts the model at numerically zero for data far from 5, which
    # leaves the least-squares fit with no gradient to follow.
    sigma_guess = col.std()
    if not np.isfinite(sigma_guess) or sigma_guess <= 0:
        sigma_guess = 1.0  # constant or single-valued column
    p0 = [hist.max(), col.mean(), sigma_guess]
    popt, _ = curve_fit(gauss, bin_centers, hist, p0=p0)
    mu = popt[1]

    granularity = 100
    x = np.linspace(col.min(), col.max(), granularity)
    if orientation == 'horizontal':
        ax.plot(gauss(x, *popt), x, c='C0', label='Fitted data')
        ax.hlines(mu, *ax.get_xlim(), colors='C3', label='Fitted mean')
        ax.set_ylabel(label)
    else:
        ax.plot(x, gauss(x, *popt), c='C0', label='Fitted data')
        ax.vlines(mu, *ax.get_ylim(), colors='C3', label='Fitted mean')
        ax.set_xlabel(label)

    if show:
        plt.show()

    return popt


def plot_gene(gene, ax=None, orientation='vertical'):
    """Plot the histogram and Gaussian fit of one column of `gene_table`."""
    plot_hist_gauss(gene_table[gene], ax=ax,
                    orientation=orientation, label=gene)


# =============================================================================
# Two-dimensional scatter + bivariate Gaussian fit
# =============================================================================

def _mesh_points(x, y):
    """Return the meshgrid of `x` and `y` stacked as a (2, len(y), len(x)) array."""
    return np.stack(np.meshgrid(x, y))


def plot_genes(x_gene=None, y_gene=None, ax=None, mode='raw',
               gene_table=gene_table):
    """
    Produces a scatterplot of the expression values (TPM, Transcripts Per
    Million, as stored in the table) of two genes, and fits a bivariate
    Gaussian to a kernel density estimate of the data, which is also plotted.

    Parameters
    ----------
    x_gene : str
        The common name of the gene to be plotted along the x-axis.
    y_gene : str
        The common name of the gene to be plotted along the y-axis.
    ax : matplotlib axes object, default None
        An axes of the current figure. If None, new axes are created and
        `plt.show()` is called before returning.
    mode : str, default 'raw'
        The mode of plotting:
        - 'raw' : plot data as is
        - 'norm' : normalize and recenter before plotting
        Any other value draws nothing and returns the whole `gene_table`
        as `plotted_data`.
    gene_table : pandas DataFrame, default global gene_table
        A table containing the two genes to be plotted as columns

    Returns
    -------
    plotted_data : pandas DataFrame
        The two columns of data that were actually plotted
    A : float
        Amplitude of optimal bivariate Gaussian
    x0 : float
        x mean of optimal bivariate Gaussian
    y0 : float
        y mean of optimal bivariate Gaussian
    sigma_x : float
        Standard deviation along x axis of optimal bivariate Gaussian
    sigma_y : float
        Standard deviation along y axis of optimal bivariate Gaussian
    rho : float
        Pearson correlation coefficient of optimal bivariate Gaussian
    z_offset : float
        Additive offset of optimal bivariate Gaussian
    """
    show = ax is None
    if ax is None:
        ax = plt.axes()
    ax.set_aspect('equal', adjustable='box')

    if x_gene is not None and y_gene is not None:
        two_cols = gene_table.loc[:, [x_gene, y_gene]]
    else:  # testing
        print('WARNING: plot_genes requires two gene names as input. '
              'You have omitted at least one, so random test data will '
              'be plotted instead.')
        x_gene, y_gene = 'x', 'y'
        test_dist = np.random.default_rng().multivariate_normal(
            mean=[100, 200],
            cov=[[1, 0.9],
                 [0.9, np.sqrt(3)]],
            size=(1000))
        two_cols = pd.DataFrame(data=test_dist, columns=[x_gene, y_gene])

    # Mean and density ---------------------------------------------------------
    mean = two_cols.mean()
    data_for_kde = two_cols.values.T
    density_estimator = gaussian_kde(data_for_kde)
    z = density_estimator(data_for_kde)  # per-point density, used as colour

    # Fit to 2D Gaussian =======================================================
    def bivariate_Gaussian(xy, A, x0, y0, sigma_x, sigma_y, rho, z_offset):
        x, y = xy
        # A should really be divided by (2*np.pi*sigma_x*sigma_y*np.sqrt(1-rho**2))
        a = 1 / (2 * (1 - rho**2) * sigma_x**2)
        b = - rho / ((1 - rho**2) * sigma_x * sigma_y)
        c = 1 / (2 * (1 - rho**2) * sigma_y**2)
        g = z_offset + A * \
            np.exp(-(a * (x - x0)**2 + b * (x - x0)
                   * (y - y0) + c * (y - y0)**2))
        return g.ravel()

    gran = 400  # granularity
    x = np.linspace(two_cols[x_gene].min(), two_cols[x_gene].max(), gran)
    y = np.linspace(two_cols[y_gene].min(), two_cols[y_gene].max(), gran)
    pts = _mesh_points(x, y).reshape(2, -1)

    # `.iloc` because `mean` is labelled by gene name; integer keys on such a
    # Series are treated as labels by newer pandas versions.
    p0 = (1, mean.iloc[0], mean.iloc[1], 1, 1, 0, 0)
    popt, _ = curve_fit(bivariate_Gaussian, pts,
                        density_estimator(pts), p0=p0)
    A, x0, y0, sigma_x, sigma_y, rho, z_offset = popt

    cov = np.array(
        [[sigma_x**2, rho * sigma_x * sigma_y],
         [rho * sigma_x * sigma_y, sigma_y**2]])
    # `cov` is symmetric, so use `eigh`: eigenvalues come back real and in
    # ascending order, i.e. column 0 is the minor axis and column 1 the major
    # axis, which is what the 'norm' axis labels below assume.
    eigenvalues, eigenvectors = np.linalg.eigh(cov)
    # eigvals are variances along ellipse axes, eigvects are direction of axes
    scaled_eigvects = np.sqrt(eigenvalues) * eigenvectors

    # Plots ====================================================================
    plotted_data = gene_table  # fallback when `mode` is not recognised
    if mode == 'raw':
        # --- Plot Data ---
        two_cols.plot.scatter(x=x_gene, y=y_gene, c=z,
                              s=2, ylabel=y_gene, ax=ax)

        # --- Plot Fitted Gaussian ---
        pts = pts.reshape(2, gran, gran)
        data_fitted = bivariate_Gaussian(pts, *popt).reshape(gran, gran)
        # contour
        ax.contour(pts[0], pts[1], data_fitted, 8,
                   cmap='viridis', zorder=0, alpha=.5)
        # center
        ax.plot(x0, y0, 'rx')
        # gene axes
        ax.quiver([x0, x0], [y0, y0], [1, 0], [0, 1],
                  angles='xy', scale_units='xy', width=0.005, scale=1,
                  color=['magenta', 'violet'], alpha=0.35)
        # ellipse axes (rows of the matrix are the x and y components)
        ax.quiver([x0, x0], [y0, y0], *scaled_eigvects,
                  angles='xy', scale_units='xy', width=0.005, scale=1,
                  color=['red', 'firebrick'], alpha=0.35)

        plotted_data = two_cols
    # --------------------------------------------------------------------------
    elif mode == 'norm':
        # Inverse of the scaled eigenvector matrix: maps the fitted ellipse
        # onto the unit circle with its axes along the plot axes.
        whitening = np.linalg.inv(scaled_eigvects)
        recentered_data = two_cols.values - [x0, y0]
        normed_data = recentered_data @ whitening.T
        normed_two_cols = pd.DataFrame(
            data=normed_data, columns=[x_gene, y_gene])

        # --- Plot Data ---
        normed_two_cols.plot.scatter(x=x_gene, y=y_gene, c=z, s=2, ax=ax,
                                     xlabel='minor axis', ylabel='major axis')

        # --- Plot Fitted Gaussian ---
        x = np.linspace(normed_two_cols[x_gene].min(),
                        normed_two_cols[x_gene].max(), gran)
        y = np.linspace(normed_two_cols[y_gene].min(),
                        normed_two_cols[y_gene].max(), gran)
        pts = _mesh_points(x, y)
        data_fitted = bivariate_Gaussian(pts, A, 0, 0, 1, 1, 0, z_offset)
        data_fitted = data_fitted.reshape(gran, gran)
        # contour
        ax.contour(pts[0], pts[1], data_fitted, 8,
                   cmap='viridis', zorder=0, alpha=.5)
        # center
        ax.plot(0, 0, 'rx')
        # gene axes (images of the original unit axes under the transform)
        ax.quiver([0, 0], [0, 0], *whitening,
                  angles='xy', scale_units='xy', width=0.005, scale=1,
                  color=['magenta', 'violet'], alpha=0.35)
        # ellipse axes
        ax.quiver([0, 0], [0, 0], [1, 0], [0, 1],
                  angles='xy', scale_units='xy', width=0.005, scale=1,
                  color=['red', 'firebrick'], alpha=0.35)

        plotted_data = normed_two_cols
    # ==========================================================================

    if show:
        plt.show()

    return (plotted_data,
            A, x0, y0, sigma_x, sigma_y, rho, z_offset)  # optimal gaussian params


# =============================================================================
# Combined figure and Gradio glue
# =============================================================================

def plot_scatter_hist(x_gene, y_gene, mode='raw'):
    """
    Draw the two-gene scatter plot with marginal histograms on a new figure.

    Parameters
    ----------
    x_gene, y_gene : str
        Column names of the global `gene_table`.
    mode : str, default 'raw'
        Forwarded to `plot_genes`.

    Returns
    -------
    fig : matplotlib Figure
        The figure that was drawn on.
    """
    fig = plt.figure(layout='constrained')
    ax = fig.add_gridspec(top=0.75, right=0.75).subplots()
    # Marginal histograms sit above and to the right of the scatter axes.
    ax_histx = ax.inset_axes([0, 1.05, 1, 0.25], sharex=ax)
    ax_histy = ax.inset_axes([1.05, 0, 0.25, 1], sharey=ax)
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    plot_result = plot_genes(x_gene, y_gene, ax=ax, mode=mode)
    plotted_data = plot_result[0]

    _, x_mu, _ = plot_hist_gauss(plotted_data[x_gene], ax=ax_histx)
    _, y_mu, _ = plot_hist_gauss(plotted_data[y_gene], ax=ax_histy,
                                 orientation='horizontal')
    ax_histx.set_ylabel('Freq')
    ax_histy.set_xlabel('Freq')

    # Carry the fitted 1-D means over onto the scatter plot.
    ax.vlines(x_mu, *ax.get_ylim(),
              label=f'{x_gene} mean', colors='C3', zorder=0)
    ax.hlines(y_mu, *ax.get_xlim(),
              label=f'{y_gene} mean', colors='C3', zorder=0)

    return fig


def plt_to_img(fig=None):
    """
    Render a figure to an RGB `PIL.Image`.

    Parameters
    ----------
    fig : matplotlib Figure, default None
        The figure to render; the current pyplot figure if None.
    """
    if fig is None:
        fig = plt.gcf()
    fig.canvas.draw()
    # `buffer_rgba` replaces `tostring_rgb`, which recent Matplotlib releases
    # deprecated and then removed. Taking the size from the buffer itself also
    # avoids a mismatch with `get_width_height()` on HiDPI canvases.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return PIL.Image.fromarray(rgba).convert('RGB')


def random_gene():
    """Return one column name of `gene_table`, chosen uniformly at random."""
    return np.random.default_rng().choice(gene_table.columns.values)


def _resolve_gene(name):
    """
    Turn user input into a column name of `gene_table`.

    Empty input selects a random gene. Otherwise the name is tried verbatim
    and then upper-cased, so lower-case input still finds upper-case symbols
    while mixed-case column names remain reachable.

    Raises
    ------
    gr.Error
        If neither spelling is a column of `gene_table`.
    """
    if name is None or name.strip() == '':
        return random_gene()
    name = name.strip()
    for candidate in (name, name.upper()):
        if candidate in gene_table.columns:
            return candidate
    raise gr.Error(f"Gene '{name}' was not found in the expression table.")


def plot_one_gene(gene):
    """Plot one gene's histogram; return `(image, gene name)`."""
    gene = _resolve_gene(gene)
    # Supply the axes explicitly so `plot_hist_gauss` does not call
    # `plt.show()` before the figure is rendered to an image.
    fig, ax = plt.subplots()
    plot_gene(gene, ax=ax)
    return plt_to_img(fig), f"{gene}"


def plot_two_genes(x_gene, y_gene, is_normed):
    """Plot two genes against each other; return `(image, title)`."""
    x_gene = _resolve_gene(x_gene)
    y_gene = _resolve_gene(y_gene)
    mode = 'norm' if is_normed else 'raw'
    fig = plot_scatter_hist(x_gene, y_gene, mode)
    return plt_to_img(fig), f"{x_gene} vs. {y_gene}"


def picker(only_one_gene, x_gene, y_gene, is_normed):
    """Gradio click handler: dispatch to the one- or two-gene plot."""
    plt.close('all')  # drop figures left over from earlier clicks
    if only_one_gene:
        return plot_one_gene(x_gene)
    return plot_two_genes(x_gene, y_gene, is_normed)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            only_one_gene = gr.Checkbox(
                label="Only plot Gene 1",
                info="By default, two genes are plotted.")
            x_gene = gr.Textbox(label='Gene 1', value='APP')
            y_gene = gr.Textbox(label='Gene 2', value='PSENEN')
            is_normed = gr.Checkbox(
                label="Normalize",
                info="Recenter and normalize the Gaussian for two genes.")
            plot_button = gr.Button("Plot")
        with gr.Column():
            image_output = gr.Image()
            text_output = gr.Textbox()

    plot_button.click(picker,
                      inputs=[only_one_gene, x_gene, y_gene, is_normed],
                      outputs=[image_output, text_output])

if __name__ == "__main__":
    demo.launch()