File size: 11,222 Bytes
e0313ac
1f745c4
e0313ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f745c4
c1076fc
e0313ac
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import copy

import gradio as gr

import PIL
import PIL.Image
import numpy as np

import scipy
from scipy.stats import gaussian_kde
from scipy.optimize import curve_fit

import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity

import matplotlib as mpl
import matplotlib.pyplot as plt

# Load the tab-separated expression table, then reshape it so that every gene
# (keyed by its 'Description') becomes a column and every remaining input
# column becomes a row with a fresh 0..n-1 index.
df = pd.read_csv(
    './gene_tpm_brain_cerebellar_hemisphere_log2minus1NEW.txt', sep='\t')
gene_table = (
    df.set_index('Description')
    .drop(columns=['id', 'Name'])
    .T
    .reset_index(drop=True)
)

# ===============================================================================================
# ===============================================================================================
# ===============================================================================================


def plot_hist_gauss(col, ax=None, orientation='vertical', label=''):
  """
  Plots a density histogram of a column and overlays a fitted Gaussian.

  Parameters
  ----------
  col : pandas Series
    The values to histogram and fit.
  ax : matplotlib axes object, default None
    Axes to draw on; passed through to pandas' histogram plot. When None,
    ``plt.show()`` is called before returning.
  orientation : str, default 'vertical'
    'horizontal' draws the histogram sideways (values along the y-axis);
    any other value draws it with values along the x-axis.
  label : str, default ''
    Label placed on the axis that carries the data values.

  Returns
  -------
  popt : numpy array
    The fitted parameters ``(A, mu, sigma)`` of
    ``A * exp(-(x - mu)**2 / (2 * sigma**2))``; sigma is returned
    non-negative.
  """
  show = ax is None

  ax = col.plot.hist(orientation=orientation, density=True,
                     alpha=0.2, ax=ax, subplots=False)

  hist, bin_edges = np.histogram(col, density=True)
  bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

  def gauss(x, A, mu, sigma):
    return A * np.exp(-(x - mu)**2 / (2. * sigma**2))

  # Seed the optimizer from the data itself. A fixed guess such as mu=5 gives
  # a numerically zero gradient when the data lie far from it, and the fit
  # then never moves away from the guess.
  values = np.asarray(col, dtype=float)
  spread = values.std()
  if not np.isfinite(spread) or spread == 0:
    spread = 1.0
  p0 = [hist.max(), values.mean(), spread]
  popt, _ = curve_fit(gauss, bin_centers, hist, p0=p0)
  # Only sigma**2 enters the model, so the optimizer may land on -sigma.
  popt[2] = abs(popt[2])
  mu = popt[1]

  granularity = 100
  x = np.linspace(col.min(), col.max(), granularity)
  if orientation == 'horizontal':
    ax.plot(gauss(x, *popt), x, c='C0', label='Fitted data')
    ax.hlines(mu, *ax.get_xlim(), colors='C3', label='Fitted mean')
    ax.set_ylabel(label)
  else:
    ax.plot(x, gauss(x, *popt), c='C0', label='Fitted data')
    ax.vlines(mu, *ax.get_ylim(), colors='C3', label='Fitted mean')
    ax.set_xlabel(label)

  if show:
    plt.show()

  return popt


def plot_gene(gene, ax=None, orientation='vertical'):
  """
  Plots the histogram and fitted Gaussian for one gene.

  Looks up the column named `gene` in the global `gene_table` and hands it to
  `plot_hist_gauss`, using the gene name as the axis label. `ax` and
  `orientation` are forwarded unchanged.
  """
  column = gene_table[gene]
  plot_hist_gauss(column, ax=ax, orientation=orientation, label=gene)

# ===============================================================================================
# ===============================================================================================
# ===============================================================================================


def plot_genes(x_gene=None, y_gene=None, ax=None, mode='raw', gene_table=gene_table):
  """
  Produces a scatterplot of the TPM (Transcripts Per Million) of two genes,
  and fits data to bivariate Gaussian which is also plotted.

  The Gaussian is fitted to a kernel density estimate of the data evaluated
  on a regular grid spanning the data range.

  Parameters
  ----------
  x_gene : str
    The common name of the gene to be plotted along the x-axis.
  y_gene : str
    The common name of the gene to be plotted along the y-axis.
    If either gene name is None, a warning is printed and random test data
    (columns 'x' and 'y') is plotted instead.
  ax : matplotlib axes object, default None
    An axes of the current figure. When None, a new axes is created with
    ``plt.axes()`` and ``plt.show()`` is called before returning.
  mode : str, default 'raw'
    The mode of plotting:

    - 'raw' : plot data as is
    - 'norm' : normalize and recenter before plotting

  gene_table : pandas DataFrame, default global gene_table
    A table containing the two genes to be plotted as columns

  Returns
  -------
  plotted_data : pandas DataFrame
    The two columns of data that were actually plotted
  A : float
    Amplitude of optimal bivariate Gaussian
  x0 : float
    x mean of optimal bivariate Gaussian
  y0 : float
    y mean of optimal bivariate Gaussian
  sigma_x : float
    Standard deviation along x axis of optimal bivariate Gaussian
    (non-negative)
  sigma_y : float
    Standard deviation along y axis of optimal bivariate Gaussian
    (non-negative)
  rho : float
    Pearson correlation coefficient of optimal bivariate Gaussian
  z_offset : float
    Additive offset of optimal bivariate Gaussian

  Raises
  ------
  ValueError
    If `mode` is neither 'raw' nor 'norm'.
  """

  # Reject unknown modes up front instead of silently drawing nothing after
  # the (expensive) density estimate and fit.
  if mode not in ('raw', 'norm'):
    raise ValueError(f"mode must be 'raw' or 'norm', got {mode!r}")

  show = ax is None
  if ax is None:
    ax = plt.axes()
  ax.set_aspect('equal', adjustable='box')
  if x_gene is not None and y_gene is not None:
    two_cols = gene_table.loc[:, [x_gene, y_gene]]
  else:  # testing
    print('WARNING: plot_genes requires two gene names as input. '
          'You have omitted at least one, so random test data will '
          'be plotted instead.')
    x_gene, y_gene = 'x', 'y'
    test_dist = np.random.default_rng().multivariate_normal(
        mean=[100, 200], cov=[[1, 0.9], [0.9, np.sqrt(3)]], size=(1000))
    two_cols = pd.DataFrame(data=test_dist, columns=[x_gene, y_gene])

  # Mean and density ---------------------------------------------------------

  # Convert to plain arrays so the values are read by position; indexing a
  # label-indexed Series with integers is deprecated in pandas.
  x_mean, y_mean = two_cols.mean().to_numpy()
  x_std, y_std = two_cols.std().to_numpy()

  data_for_kde = two_cols.values.T
  density_estimator = gaussian_kde(data_for_kde)
  z = density_estimator(data_for_kde)  # density at each sample, used as colour

  # Fit to 2D Gaussian =======================================================

  def bivariate_Gaussian(xy, A, x0, y0, sigma_x, sigma_y, rho, z_offset):
    x, y = xy

    # A should really be divided by (2*np.pi*sigma_x*sigma_y*np.sqrt(1-rho**2))
    a = 1 / (2 * (1 - rho**2) * sigma_x**2)
    b = - rho / ((1 - rho**2) * sigma_x * sigma_y)
    c = 1 / (2 * (1 - rho**2) * sigma_y**2)
    g = z_offset + A * \
        np.exp(-(a * (x - x0)**2 + b * (x - x0) * (y - y0) + c * (y - y0)**2))

    return g.ravel()

  gran = 400  # granularity
  x = np.linspace(two_cols[x_gene].min(), two_cols[x_gene].max(), gran)
  y = np.linspace(two_cols[y_gene].min(), two_cols[y_gene].max(), gran)
  # Grid of evaluation points, flattened to shape (2, gran * gran).
  pts = np.stack(np.meshgrid(x, y)).reshape(2, -1)

  # Start the optimizer from the scale of the data rather than from unit
  # amplitude and unit widths.
  grid_density = density_estimator(pts)
  p0 = (grid_density.max(), x_mean, y_mean, x_std, y_std, 0, 0)
  popt, pcov = curve_fit(bivariate_Gaussian, pts, grid_density, p0=p0)
  A, x0, y0, sigma_x, sigma_y, rho, z_offset = popt

  # The model is unchanged when a sigma and rho flip sign together, so the
  # optimizer may return a negative sigma. Fold the sign into rho so the
  # reported standard deviations are non-negative; the surface is identical.
  if sigma_x < 0:
    sigma_x, rho = -sigma_x, -rho
  if sigma_y < 0:
    sigma_y, rho = -sigma_y, -rho
  fit_params = (A, x0, y0, sigma_x, sigma_y, rho, z_offset)

  cov = np.array(
      [[sigma_x**2, rho * sigma_x * sigma_y],
       [rho * sigma_x * sigma_y, sigma_y**2]])
  # eigh is meant for symmetric matrices and returns eigenvalues in ascending
  # order, so column 0 is the minor ellipse axis and column 1 the major one,
  # matching the axis labels used in 'norm' mode.
  eigenvalues, eigenvectors = np.linalg.eigh(cov)
  # eigvals are variances along ellipse axes, eigvects are direction of axes
  scaled_eigvects = np.sqrt(eigenvalues) * eigenvectors

  # Plots ====================================================================

  if mode == 'raw':
    # --- Plot Data ---
    two_cols.plot.scatter(x=x_gene, y=y_gene, c=z,
                          s=2, ylabel=y_gene, ax=ax)

    # --- Plot Fitted Gaussian ---
    pts = pts.reshape(2, gran, gran)
    data_fitted = bivariate_Gaussian(pts, *fit_params).reshape(gran, gran)

    # contour
    ax.contour(pts[0], pts[1], data_fitted, 8,
               cmap='viridis', zorder=0, alpha=.5)

    # center
    ax.plot(x0, y0, 'rx')

    # gene axes
    ax.quiver([x0, x0], [y0, y0], [1, 0], [0, 1], angles='xy', scale_units='xy',
              width=0.005, scale=1, color=['magenta', 'violet'], alpha=0.35)

    # ellipse axes (unpacking the rows gives the x- and y-components of the
    # two column vectors)
    ax.quiver([x0, x0], [y0, y0], *scaled_eigvects, angles='xy', scale_units='xy',
              width=0.005, scale=1, color=['red', 'firebrick'], alpha=0.35)

    plotted_data = two_cols

  # --------------------------------------------------------------------------

  else:  # mode == 'norm'
    # Change of basis from gene coordinates to the scaled ellipse axes.
    inv_cov = np.linalg.inv(scaled_eigvects)
    recentered_data = two_cols.values - [x0, y0]
    normed_data = recentered_data @ inv_cov.T
    normed_two_cols = pd.DataFrame(
        data=normed_data, columns=[x_gene, y_gene])

    # --- Plot Data ---
    normed_two_cols.plot.scatter(x=x_gene, y=y_gene, c=z, s=2, ax=ax,
                                 xlabel='minor axis',
                                 ylabel='major axis')

    # --- Plot Fitted Gaussian ---
    x = np.linspace(normed_two_cols[x_gene].min(),
                    normed_two_cols[x_gene].max(), gran)
    y = np.linspace(normed_two_cols[y_gene].min(),
                    normed_two_cols[y_gene].max(), gran)
    pts = np.stack(np.meshgrid(x, y))  # shape (2, gran, gran)

    # In the normalized frame the fitted Gaussian is centred at the origin
    # with unit widths and no correlation.
    data_fitted = bivariate_Gaussian(pts, A, 0, 0, 1, 1, 0, z_offset)
    data_fitted = data_fitted.reshape(gran, gran)

    # contour
    ax.contour(pts[0], pts[1], data_fitted, 8,
               cmap='viridis', zorder=0, alpha=.5)

    # center
    ax.plot(0, 0, 'rx')

    # gene axes (the original unit vectors expressed in the new basis)
    ax.quiver([0, 0], [0, 0], *inv_cov, angles='xy', scale_units='xy',
              width=0.005, scale=1, color=['magenta', 'violet'], alpha=0.35)

    # ellipse axes
    ax.quiver([0, 0], [0, 0], [1, 0], [0, 1], angles='xy', scale_units='xy',
              width=0.005, scale=1, color=['red', 'firebrick'], alpha=0.35)

    plotted_data = normed_two_cols

  # ==========================================================================

  if show:
    plt.show()

  return (plotted_data,
          A, x0, y0, sigma_x, sigma_y, rho, z_offset)  # optimal gaussian params

# ===============================================================================================
# ===============================================================================================
# ===============================================================================================


def plot_scatter_hist(x_gene, y_gene, mode='raw'):
  """
  Draws the two-gene scatter plot with a marginal histogram on each side.

  A new figure is created. The main axes receives the `plot_genes` scatter
  and fitted contours; two inset axes (above and to the right, sharing the
  corresponding main axis) receive the `plot_hist_gauss` histograms of the
  plotted x and y data. The fitted mean of each marginal is then drawn as a
  line across the main axes.

  Parameters
  ----------
  x_gene : str
    Gene plotted along the x-axis.
  y_gene : str
    Gene plotted along the y-axis.
  mode : str, default 'raw'
    Forwarded to `plot_genes`.
  """
  fig = plt.figure(layout='constrained')
  grid = fig.add_gridspec(top=0.75, right=0.75)
  ax = grid.subplots()
  # The aspect ratio of the main axes is set inside plot_genes.
  ax_histx = ax.inset_axes([0, 1.05, 1, 0.25], sharex=ax)
  ax_histy = ax.inset_axes([1.05, 0, 0.25, 1], sharey=ax)

  # Hide the tick labels the marginals would duplicate from the main axes.
  ax_histx.tick_params(axis="x", labelbottom=False)
  ax_histy.tick_params(axis="y", labelleft=False)

  # Only the plotted data is needed from the scatter; only the fitted mean is
  # needed from each marginal fit.
  plotted_data, *_ = plot_genes(x_gene, y_gene, ax=ax, mode=mode)
  _, x_mu, _ = plot_hist_gauss(plotted_data[x_gene], ax=ax_histx)
  _, y_mu, _ = plot_hist_gauss(
      plotted_data[y_gene], ax=ax_histy, orientation='horizontal')
  ax_histx.set_ylabel('Freq')
  ax_histy.set_xlabel('Freq')

  y_low, y_high = ax.get_ylim()
  ax.vlines(x_mu, y_low, y_high,
            label=f'{x_gene} mean', colors='C3', zorder=0)
  x_low, x_high = ax.get_xlim()
  ax.hlines(y_mu, x_low, x_high,
            label=f'{y_gene} mean', colors='C3', zorder=0)

  # ax.fill_between([plotted_data[x_gene].min(), plotted_data[x_gene].max()],
  #                 *ax.get_ylim(), color='C0', alpha=0.01, lw=0)
  # ax.fill_betweenx([plotted_data[y_gene].min(), plotted_data[y_gene].max()],
  #                 *ax.get_xlim(), color='C0', alpha=0.01, lw=0)


def create_correct_gene_plot(genes, mode):
  """
  Renders the plot that matches the number of selected genes as an image.

  Parameters
  ----------
  genes : list of str
    The selected gene names. One gene produces a histogram with a fitted
    Gaussian (`plot_gene`); two genes produce the scatter plot with marginal
    histograms (`plot_scatter_hist`).
  mode : bool
    Whether to normalize. Truthy selects `plot_scatter_hist`'s 'norm' mode,
    falsy selects 'raw'. Ignored when a single gene is selected.

  Returns
  -------
  image : PIL.Image.Image
    The rendered figure in RGB mode.

  Raises
  ------
  gr.Error
    If no gene or more than two genes are selected.
  """
  if not genes:
    raise gr.Error("Please select at least one gene to plot.")
  if len(genes) > 2:
    raise gr.Error("Cannot plot more than two genes at a time.")

  # Remember which figures already exist so that only the ones created for
  # this request are closed afterwards, even if plotting fails.
  open_before = set(plt.get_fignums())
  try:
    if len(genes) == 1:
      # Draw on a fresh axes so nothing is painted over an older figure.
      fig, ax = plt.subplots()
      plot_gene(genes[0], ax=ax)
    else:
      x_gene, y_gene = genes
      plot_scatter_hist(x_gene, y_gene, mode='norm' if mode else 'raw')
      fig = plt.gcf()  # the figure plot_scatter_hist just created

    # The canvas must be drawn before its pixel buffer is read.
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return PIL.Image.fromarray(rgba).convert('RGB')
  finally:
    for num in set(plt.get_fignums()) - open_before:
      plt.close(num)


# Gradio UI: a multi-select gene picker and a "Normalize" checkbox feed
# create_correct_gene_plot, whose returned image is displayed.
demo = gr.Interface(
    fn=create_correct_gene_plot,
    inputs=[
        gr.Dropdown(
            # Gradio documents `choices` as a list; pass one rather than a
            # pandas Index, whose truth value is ambiguous.
            choices=gene_table.columns.tolist(),
            value=["APP", "PSENEN"],
            multiselect=True,
            label="Genes",
            info="Select one or two genes to plot.",
        ),
        gr.Checkbox(
            label="Normalize",
            info="Recenter and normalize the Gaussian for two genes.",
        ),
    ],
    outputs="image",
)

if __name__ == "__main__":
  demo.launch()