Ronak Ramachandran committed on
Commit
e0313ac
1 Parent(s): 5e527c3
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.txt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,9 +1,321 @@
1
- import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox")
7
-
8
  if __name__ == "__main__":
9
- demo.launch()
 
 
1
+ import gradio as gr
2
 
3
+ import PIL
4
+ import numpy as np
5
+
6
+ import scipy
7
+ from scipy.stats import gaussian_kde
8
+ from scipy.optimize import curve_fit
9
+
10
+ import pandas as pd
11
+
12
+ from sklearn.preprocessing import StandardScaler
13
+ from sklearn.decomposition import PCA
14
+ from sklearn.neighbors import KernelDensity
15
+
16
+ import matplotlib as mpl
17
+ import matplotlib.pyplot as plt
18
+
19
+ import copy
20
+
21
# Load the tab-separated expression table that ships alongside the app.
df = pd.read_csv(
    './gene_tpm_brain_cerebellar_hemisphere_log2minus1NEW.txt',
    sep='\t',
)

# Re-orient the table so every gene description becomes a column:
# index by 'Description', discard the identifier columns, transpose, and
# replace the resulting row labels with a plain 0..n-1 index.
# (Rows are presumably individual samples -- confirm against the data file.)
gene_table = (
    df.set_index('Description')
    .drop(columns=['id', 'Name'])
    .transpose()
    .reset_index(drop=True)
)
25
+
26
+ # ===============================================================================================
27
+ # ===============================================================================================
28
+ # ===============================================================================================
29
+
30
+
31
def plot_hist_gauss(col, ax=None, orientation='vertical', label=''):
    """
    Plot a density histogram of a column and overlay a fitted Gaussian.

    Parameters
    ----------
    col : pandas Series
        The values to histogram and fit.
    ax : matplotlib axes object, default None
        Axes to draw on. When None, pandas chooses the axes and
        ``plt.show()`` is called before returning.
    orientation : str, default 'vertical'
        Histogram orientation. With 'horizontal' the fitted curve and mean
        line are drawn along the y-axis; any other value draws them along
        the x-axis.
    label : str, default ''
        Label applied to the data axis.

    Returns
    -------
    popt : ndarray
        Optimal ``(A, mu, sigma)`` of the fitted Gaussian
        ``A * exp(-(x - mu)**2 / (2 * sigma**2))``.

    Raises
    ------
    RuntimeError
        Propagated from ``scipy.optimize.curve_fit`` when the fit fails.
    """
    show = ax is None

    ax = col.plot.hist(orientation=orientation, density=True,
                       alpha=0.2, ax=ax, subplots=False)

    hist, bin_edges = np.histogram(col, density=True)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    def gauss(x, A, mu, sigma):
        return A * np.exp(-(x - mu)**2 / (2. * sigma**2))

    # Seed the optimiser from the data itself. A fixed guess such as
    # mu=5 only works for data that happens to sit near 5 and can make the
    # fit fail on columns centred elsewhere (e.g. normalised data at 0).
    values = np.asarray(col, dtype=float)
    spread = float(np.std(values))
    if not np.isfinite(spread) or spread == 0:
        spread = 1.0
    p0 = [float(hist.max()), float(np.mean(values)), spread]

    popt, _ = curve_fit(gauss, bin_centers, hist, p0=p0)
    mu = popt[1]

    granularity = 100
    x = np.linspace(col.min(), col.max(), granularity)
    if orientation == 'horizontal':
        ax.plot(gauss(x, *popt), x, c='C0', label='Fitted data')
        ax.hlines(mu, *ax.get_xlim(), colors='C3', label='Fitted mean')
        ax.set_ylabel(label)
    else:
        ax.plot(x, gauss(x, *popt), c='C0', label='Fitted data')
        ax.vlines(mu, *ax.get_ylim(), colors='C3', label='Fitted mean')
        ax.set_xlabel(label)

    if show:
        plt.show()

    return popt
62
+
63
+
64
def plot_gene(gene, ax=None, orientation='vertical'):
    """
    Draw the histogram-with-Gaussian plot for one column of ``gene_table``.

    Parameters
    ----------
    gene : str
        Column of the global ``gene_table`` to plot; also used as the
        axis label.
    ax : matplotlib axes object, default None
        Forwarded to ``plot_hist_gauss``.
    orientation : str, default 'vertical'
        Forwarded to ``plot_hist_gauss``.
    """
    expression = gene_table[gene]
    plot_hist_gauss(expression, ax=ax, orientation=orientation, label=gene)
67
+
68
+ # ===============================================================================================
69
+ # ===============================================================================================
70
+ # ===============================================================================================
71
+
72
+
73
def plot_genes(x_gene=None, y_gene=None, ax=None, mode='raw', gene_table=gene_table):
    """
    Produces a scatterplot of the TPM (Transcripts Per Million) of two genes,
    and fits data to bivariate Gaussian which is also plotted.

    Parameters
    ----------
    x_gene : str
        The common name of the gene to be plotted along the x-axis.
    y_gene : str
        The common name of the gene to be plotted along the y-axis.
        If either gene name is None, a warning is printed and random
        test data (columns 'x' and 'y') is plotted instead.
    ax : matplotlib axes object, default None
        An axes of the current figure. When None, a new axes is created
        and ``plt.show()`` is called before returning.
    mode : str, default 'raw'
        The mode of plotting:

        - 'raw' : plot data as is
        - 'norm' : normalize and recenter before plotting

        Any other value draws nothing and returns the whole
        ``gene_table`` as ``plotted_data``.

    gene_table : pandas DataFrame, default global gene_table
        A table containing the two genes to be plotted as columns

    Returns
    -------
    plotted_data : pandas DataFrame
        The two columns of data that were actually plotted
    A : float
        Amplitude of optimal bivariate Gaussian
    x0 : float
        x mean of optimal bivariate Gaussian
    y0 : float
        y mean of optimal bivariate Gaussian
    sigma_x : float
        Standard deviation along x axis of optimal bivariate Gaussian
    sigma_y : float
        Standard deviation along y axis of optimal bivariate Gaussian
    rho : float
        Pearson correlation coefficient of optimal bivariate Gaussian
    z_offset : float
        Additive offset of optimal bivariate Gaussian

    Raises
    ------
    RuntimeError
        Propagated from ``scipy.optimize.curve_fit`` when the fit fails.
    """

    show = ax is None
    if ax is None:
        ax = plt.axes()
        # NOTE(review): equal aspect is assumed to apply only to axes
        # created here, not to caller-supplied axes -- confirm.
        ax.set_aspect('equal', adjustable='box')
    if x_gene is not None and y_gene is not None:
        two_cols = gene_table.loc[:, [x_gene, y_gene]]
    else:  # testing
        print('WARNING: plot_genes requires two gene names as input. '
              'You have omitted at least one, so random test data will '
              'be plotted instead.')
        x_gene, y_gene = 'x', 'y'
        test_dist = np.random.default_rng().multivariate_normal(
            mean=[100, 200], cov=[[1, 0.9], [0.9, np.sqrt(3)]], size=(1000))
        two_cols = pd.DataFrame(data=test_dist, columns=[x_gene, y_gene])

    # Mean and density ---------------------------------------------------------

    mean = two_cols.mean()

    data_for_kde = two_cols.values.T
    density_estimator = gaussian_kde(data_for_kde)
    z = density_estimator(data_for_kde)  # per-point density, used as colour

    # Fit to 2D Gaussian =======================================================

    def bivariate_Gaussian(xy, A, x0, y0, sigma_x, sigma_y, rho, z_offset):
        x, y = xy

        # A should really be divided by (2*np.pi*sigma_x*sigma_y*np.sqrt(1-rho**2))
        a = 1 / (2 * (1 - rho**2) * sigma_x**2)
        b = - rho / ((1 - rho**2) * sigma_x * sigma_y)
        c = 1 / (2 * (1 - rho**2) * sigma_y**2)
        g = z_offset + A * \
            np.exp(-(a * (x - x0)**2 + b * (x - x0) * (y - y0) + c * (y - y0)**2))

        return g.ravel()

    gran = 400  # granularity
    x = np.linspace(two_cols[x_gene].min(), two_cols[x_gene].max(), gran)
    y = np.linspace(two_cols[y_gene].min(), two_cols[y_gene].max(), gran)
    pts = np.transpose(np.dstack(np.meshgrid(x, y)),
                       axes=[2, 0, 1]).reshape(2, -1)

    # The Series is labelled by gene name, so the means must be taken by
    # position with .iloc (integer keys on a string-labelled Series are
    # deprecated as positional lookups).
    p0 = (1, mean.iloc[0], mean.iloc[1], 1, 1, 0, 0)
    popt, _ = curve_fit(bivariate_Gaussian, pts,
                        density_estimator(pts), p0=p0)
    A, x0, y0, sigma_x, sigma_y, rho, z_offset = popt

    cov = np.array(
        [[sigma_x**2, rho * sigma_x * sigma_y],
         [rho * sigma_x * sigma_y, sigma_y**2]])
    # cov is symmetric, so use eigh: it returns real eigenvalues in
    # ascending order, which makes the first axis the minor one and the
    # second the major one -- matching the axis labels used in 'norm' mode.
    eigenvalues, eigenvectors = np.linalg.eigh(cov)
    # eigvals are variances along ellipse axes, eigvects are direction of axes
    scaled_eigvects = np.sqrt(eigenvalues) * eigenvectors

    # Plots ====================================================================

    plotted_data = gene_table

    if mode == 'raw':
        # --- Plot Data ---
        two_cols.plot.scatter(x=x_gene, y=y_gene, c=z,
                              s=2, ylabel=y_gene, ax=ax)

        # --- Plot Fitted Gaussian ---
        pts = pts.reshape(2, gran, gran)
        data_fitted = bivariate_Gaussian(pts, *popt).reshape(gran, gran)

        # contour
        ax.contour(pts[0], pts[1], data_fitted, 8,
                   cmap='viridis', zorder=0, alpha=.5)

        # center
        ax.plot(x0, y0, 'rx')

        # gene axes
        ax.quiver([x0, x0], [y0, y0], [1, 0], [0, 1], angles='xy', scale_units='xy',
                  width=0.005, scale=1, color=['magenta', 'violet'], alpha=0.35)

        # ellipse axes (rows of scaled_eigvects are the x and y components)
        ax.quiver([x0, x0], [y0, y0], *scaled_eigvects, angles='xy', scale_units='xy',
                  width=0.005, scale=1, color=['red', 'firebrick'], alpha=0.35)

        plotted_data = two_cols

    # --------------------------------------------------------------------------

    elif mode == 'norm':
        # Inverse of the scaled eigenvector matrix: maps recentred data
        # into coordinates along the (unit-variance) ellipse axes.
        inv_axes = np.linalg.inv(scaled_eigvects)
        recentered_data = two_cols.values - [x0, y0]
        normed_data = recentered_data @ inv_axes.T
        normed_two_cols = pd.DataFrame(
            data=normed_data, columns=[x_gene, y_gene])

        # --- Plot Data ---
        normed_two_cols.plot.scatter(x=x_gene, y=y_gene, c=z, s=2, ax=ax,
                                     xlabel='minor axis',
                                     ylabel='major axis')

        # --- Plot Fitted Gaussian ---
        x = np.linspace(normed_two_cols[x_gene].min(),
                        normed_two_cols[x_gene].max(), gran)
        y = np.linspace(normed_two_cols[y_gene].min(),
                        normed_two_cols[y_gene].max(), gran)
        pts = np.transpose(np.dstack(np.meshgrid(x, y)), axes=[2, 0, 1])

        data_fitted = bivariate_Gaussian(pts, A, 0, 0, 1, 1, 0, z_offset)
        data_fitted = data_fitted.reshape(gran, gran)

        # contour
        ax.contour(pts[0], pts[1], data_fitted, 8,
                   cmap='viridis', zorder=0, alpha=.5)

        # center
        ax.plot(0, 0, 'rx')

        # gene axes
        ax.quiver([0, 0], [0, 0], *inv_axes, angles='xy', scale_units='xy',
                  width=0.005, scale=1, color=['magenta', 'violet'], alpha=0.35)

        # ellipse axes
        ax.quiver([0, 0], [0, 0], [1, 0], [0, 1], angles='xy', scale_units='xy',
                  width=0.005, scale=1, color=['red', 'firebrick'], alpha=0.35)

        plotted_data = normed_two_cols

    # ==========================================================================

    if show:
        plt.show()

    return (plotted_data,
            A, x0, y0, sigma_x, sigma_y, rho, z_offset)  # optimal gaussian params
249
+
250
+ # ===============================================================================================
251
+ # ===============================================================================================
252
+ # ===============================================================================================
253
+
254
+
255
def plot_scatter_hist(x_gene, y_gene, mode='raw'):
    """
    Draw a two-gene scatter plot with marginal histograms on a new figure.

    The main axes receives the ``plot_genes`` scatter/contour plot; inset
    axes above and to the right receive ``plot_hist_gauss`` histograms of
    the plotted x and y columns. The fitted mean of each marginal is then
    marked on the main axes.

    Parameters
    ----------
    x_gene : str
        Gene plotted along the x-axis.
    y_gene : str
        Gene plotted along the y-axis.
    mode : str, default 'raw'
        Forwarded to ``plot_genes``.
    """
    fig = plt.figure(layout='constrained')
    grid = fig.add_gridspec(top=0.75, right=0.75)
    ax = grid.subplots()

    # Marginal axes sit just outside the main axes and share its data axis.
    ax_histx = ax.inset_axes([0, 1.05, 1, 0.25], sharex=ax)
    ax_histy = ax.inset_axes([1.05, 0, 0.25, 1], sharey=ax)
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    plotted_data, *_ = plot_genes(x_gene, y_gene, ax=ax, mode=mode)

    _, x_mu, _ = plot_hist_gauss(plotted_data[x_gene], ax=ax_histx)
    _, y_mu, _ = plot_hist_gauss(plotted_data[y_gene], ax=ax_histy,
                                 orientation='horizontal')
    ax_histx.set_ylabel('Freq')
    ax_histy.set_xlabel('Freq')

    # Limits are read immediately before each line is drawn, so the
    # horizontal line sees whatever x-limits exist after the vertical one.
    y_low, y_high = ax.get_ylim()
    ax.vlines(x_mu, y_low, y_high,
              label=f'{x_gene} mean', colors='C3', zorder=0)
    x_low, x_high = ax.get_xlim()
    ax.hlines(y_mu, x_low, x_high,
              label=f'{y_gene} mean', colors='C3', zorder=0)
282
+
283
+
284
def create_correct_gene_plot(genes, mode):
    """
    Render the plot matching the number of selected genes as an image.

    One gene produces a histogram with a fitted Gaussian (``plot_gene``);
    two genes produce a scatter plot with marginal histograms
    (``plot_scatter_hist``).

    Parameters
    ----------
    genes : list of str
        The selected gene names (one or two).
    mode : bool
        Whether to recenter and normalize the two-gene plot. Truthy selects
        the 'norm' plotting mode, falsy selects 'raw'. Ignored for a
        single gene.

    Returns
    -------
    PIL.Image.Image
        The rendered figure as an RGB image.

    Raises
    ------
    gr.Error
        If no gene, or more than two genes, are selected.
    """
    from PIL import Image  # ensure the PIL.Image submodule is loaded

    if not genes:
        raise gr.Error("Please select at least one gene to plot.")
    if len(genes) > 2:
        raise gr.Error("Cannot plot more than two genes at a time.")

    try:
        if len(genes) == 1:
            plot_gene(genes[0])
        else:
            x_gene, y_gene = genes
            plot_scatter_hist(x_gene, y_gene,
                              mode='norm' if mode else 'raw')

        fig = plt.gcf()
        # The canvas must be rendered before its pixel buffer is read.
        # Assumes an Agg-based canvas, as the original tostring_rgb call did.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request leaks a figure and later single-gene plots would be drawn
        # on top of a previous request's axes.
        plt.close('all')
299
+
300
+
301
# Gradio UI: a multi-select gene dropdown plus a "Normalize" checkbox feed
# create_correct_gene_plot, whose result is shown as an image.
demo = gr.Interface(
    fn=create_correct_gene_plot,
    inputs=[
        gr.Dropdown(
            gene_table.columns,
            value=["APP", "PSENEN"],
            multiselect=True,
            label="Genes",
            info="Select one or two genes to plot.",
        ),
        gr.Checkbox(
            label="Normalize",
            info="Recenter and normalize the Gaussian for two genes.",
        ),
    ],
    outputs="image",
)


if __name__ == "__main__":
    demo.launch()
321
+
gene_tpm_brain_cerebellar_hemisphere_log2minus1NEW.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7146a7abf52c322bbd46760cb393cbb4d0dc7ae20bd1ddc23c62e7553757537e
3
+ size 99254578