Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import scipy | |
| from scipy.sparse import tril, triu | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| from tensorflow.keras.models import model_from_json | |
| from huggingface_hub import hf_hub_download | |
# Default anchor-to-anchor input file, fetched from the dylanplummer/hicorr dataset.
# (The same dataset also provides filename="arima_beta.chr22".)
# DATASET_SECRET is used as the access token when set; when it is not set, token=None lets
# hf_hub_download fall back to locally saved Hugging Face credentials instead of the
# script dying with a KeyError at import time.
input_file = hf_hub_download(repo_id="dylanplummer/hicorr", filename="ORC2.chr22", repo_type="dataset",
                             token=os.environ.get('DATASET_SECRET'))
data_dir = 'data/'
sparse_data_dir = 'data/sparse_data/'  # default root directory for cached .npz sparse matrices
| def get_chromosome_from_filename(filename): | |
| """ | |
| Extract the chromosome string from any of the file name formats we use | |
| Args: | |
| filename (:obj:`str`) : name of anchor to anchor file | |
| Returns: | |
| Chromosome string of form chr<> | |
| """ | |
| chr_index = filename.find('chr') # index of chromosome name | |
| if chr_index == 0: # if chromosome name is file prefix | |
| return filename[:filename.find('.')] | |
| file_ending_index = filename.rfind('.') # index of file ending | |
| if chr_index > file_ending_index: # if chromosome name is file ending | |
| return filename[chr_index:] | |
| else: | |
| return filename[chr_index: file_ending_index] | |
def _heatmap_breaks(matrix, color_scale):
    """Return strictly increasing (length >= 2) BoundaryNorm boundaries for ``draw_heatmap``."""
    max_val = np.max(matrix)
    if color_scale != 0:
        # Fixed scale: 18 even steps from 1.001 towards color_scale, topped by the matrix max.
        if color_scale > 1.001:
            ramp = np.arange(1.001, color_scale, (color_scale - 1.001) / 18)
        else:
            ramp = np.array([1.001])  # a non-positive step would crash / reverse np.arange
        breaks = np.append(ramp, max_val)
    elif max_val < 2:
        # Weak signal: 19 even steps between 1.001 and the matrix max.
        if max_val > 1.001:
            breaks = np.arange(1.001, max_val, (max_val - 1.001) / 19)
        else:
            breaks = np.array([1.001])  # nothing above 1.001, i.e. no enriched pixels
    else:
        q95 = np.quantile(matrix, q=0.95)
        step = (q95 - 1) / 18
        up = q95 + 0.011
        if up < 2:
            up = 2
            step = 0.999 / 18
        breaks = np.append(np.arange(1.001, up, step), max_val)
    # BoundaryNorm needs strictly increasing boundaries: drop any entry (typically the
    # appended max) that does not exceed every entry before it.
    running_max = np.maximum.accumulate(breaks)
    keep = np.concatenate(([True], breaks[1:] > running_max[:-1]))
    breaks = breaks[keep]
    if breaks.size < 2:
        # A single boundary cannot define a bin; add a tiny one above it.
        breaks = np.append(breaks, breaks[-1] + 1e-3)
    return breaks


def draw_heatmap(matrix, color_scale, ax=None, return_image=False):
    """
    Display ratio heatmap containing only strong signals (values above 1.001; with an automatic
    scale the top of the ramp follows the 0.95 quantile)

    Values at or below the first boundary (1.001) are drawn in the lowest (white) color.

    Args:
        matrix (:obj:`numpy.array`) : ratio matrix to be displayed
        color_scale (:obj:`int`) : max ratio value to be considered strongest by color mapping; 0 selects an automatic scale
        ax (:obj:`matplotlib.axes.Axes`) : axes which will contain the heatmap. If None, new axes are created
        return_image (:obj:`bool`) : set to True to return the array held by the drawn image and close its figure
    Returns:
        ``numpy.array`` : if ``return_image`` is set to True, the image array from ``AxesImage.get_array()``; otherwise None
    """
    breaks = _heatmap_breaks(matrix, color_scale)
    n_bin = 20  # Discretizes the interpolation into bins; fewer bins give a coarser colormap
    colors = ["#FFFFFF", "#FFE4E4", "#FFD7D7", "#FFC9C9", "#FFBCBC", "#FFAEAE", "#FFA1A1", "#FF9494", "#FF8686",
              "#FF7979", "#FF6B6B", "#FF5E5E", "#FF5151", "#FF4343", "#FF3636", "#FF2828", "#FF1B1B", "#FF0D0D",
              "#FF0000"]
    cmap_name = 'my_list'
    cm = matplotlib.colors.LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bin)
    norm = matplotlib.colors.BoundaryNorm(breaks, 20)
    if ax is None:
        _, ax = plt.subplots()
    img = ax.imshow(matrix, cmap=cm, norm=norm, interpolation='nearest')
    if return_image:
        # Close the figure that actually holds this image, not whatever figure is "current".
        plt.close(ax.figure)
        return img.get_array()
    return None
def anchor_list_to_dict(anchors):
    """
    Build a lookup from anchor name to its position in ``anchors``

    When a name occurs more than once, the position of its last occurrence is kept.

    Args:
        anchors (:obj:`numpy.array`) : array of anchor name values
    Returns:
        `dict` : anchor name -> index of that anchor in ``anchors``
    """
    return {name: position for position, name in enumerate(anchors)}
def anchor_to_locus(anchor_dict):
    """
    Wrap a name -> index dictionary in a one-argument function (e.g. for ``np.vectorize``)

    Args:
        anchor_dict (:obj:`dict`) : dictionary mapping each anchor to its chromosomal index
    Returns:
        `function` : callable taking an anchor name and returning its index; raises KeyError for unknown names
    """
    def lookup(anchor_name):
        return anchor_dict[anchor_name]

    return lookup
def _anchor_loci(anchor_dict, names):
    """Map an array of anchor names to matrix indices (KeyError for anchors not in ``anchor_dict``)."""
    return np.vectorize(anchor_to_locus(anchor_dict), otypes=[np.int64])(names)


def _read_anchor_to_anchor(path, anchor_dict, matrix_size, dummy, use_raw):
    """Parse a tab-separated anchor-to-anchor file into a ``matrix_size`` x ``matrix_size`` CSR matrix."""
    table = pd.read_csv(path, sep='\t', header=None)
    n_cols = table.shape[1]
    if n_cols < 3:
        raise ValueError('Anchor to anchor file %s must have at least 3 tab-separated columns, found %d'
                         % (path, n_cols))
    rows = _anchor_loci(anchor_dict, table[0].values)  # anchor1 names -> row indices
    cols = _anchor_loci(anchor_dict, table[1].values)  # anchor2 names -> column indices
    if n_cols >= 4:  # <a1> <a2> <obs> <exp> (any further columns are ignored)
        obs, exp = table[2], table[3]
        values = obs if use_raw else (obs + dummy) / (exp + dummy)
    else:  # <a1> <a2> <ratio>
        if use_raw:
            raise ValueError('use_raw=True needs an <anchor1> <anchor2> <obs> <exp> file, but %s has only a ratio column'
                             % path)
        values = table[2]
    # Duplicate (row, col) pairs are summed by the CSR constructor.
    return scipy.sparse.csr_matrix((values.to_numpy(), (rows, cols)), shape=(matrix_size, matrix_size))


def _force_symmetry(sparse_matrix):
    """Return a symmetric matrix built by mirroring the upper triangle of ``sparse_matrix``."""
    upper_sum = triu(sparse_matrix, k=1).sum()
    lower_sum = tril(sparse_matrix, k=-1).sum()
    if upper_sum == 0 or lower_sum == 0:
        # Only one off-diagonal triangle holds data: fold it into the upper triangle first.
        sparse_matrix = sparse_matrix + sparse_matrix.transpose()
    sparse_triu = scipy.sparse.triu(sparse_matrix)
    # NOTE(review): triu() keeps the diagonal, so adding its transpose doubles the diagonal
    # (and the fold above doubles it once more). Kept unchanged because downstream models may
    # have been trained on matrices built this way; confirm before changing.
    return sparse_triu + sparse_triu.transpose()


def load_chr_ratio_matrix_from_sparse(dir_name, file_name, anchor_dir, sparse_dir=None, anchor_list=None, chr_name=None, dummy=5, ignore_sparse=False, force_symmetry=True, use_raw=False):
    """
    Loads data as a sparse matrix by either reading a precompiled sparse matrix or an anchor to anchor file which is converted to sparse CSR format.

    Ratio values are computed using the observed (obs) and expected (exp) values:

    .. math::
        ratio = \\frac{obs + dummy}{exp + dummy}

    Anchor to anchor files are tab separated and have either 4 columns (``<a1> <a2> <obs> <exp>``)
    or 3 columns (``<a1> <a2> <ratio>``).

    Args:
        dir_name (:obj:`str`) : directory containing the anchor to anchor or precompiled (.npz) sparse matrix file
        file_name (:obj:`str`) : name of anchor to anchor or precompiled (.npz) sparse matrix file
        anchor_dir (:obj:`str`) : directory containing the reference anchor ``<chr_name>.bed`` files
        sparse_dir (:obj:`str`) : root directory of the .npz cache; when given it also replaces the module-level ``sparse_data_dir`` for later calls
        anchor_list (:obj:`pandas.DataFrame`) : anchor table with an ``anchor`` column; read from ``anchor_dir`` when None
        chr_name (:obj:`str`) : chromosome name; derived from ``file_name`` when None
        dummy (:obj:`int`) : dummy value to used when computing ratio values
        ignore_sparse (:obj:`bool`) : set to True to ignore precompiled sparse matrices even if they exist (and to skip writing one)
        force_symmetry (:obj:`bool`) : mirror the upper triangle onto the lower one for freshly parsed files
        use_raw (:obj:`bool`) : store the raw ``obs`` values instead of ratios (4-column files only)
    Returns:
        ``scipy.sparse`` matrix of ratio values (CSR for freshly parsed files)
    Raises:
        ValueError: if no anchor list or anchor directory is available, the file has fewer than 3
            columns, or ``use_raw`` is requested for a 3-column file
    Note:
        The cache is keyed only by ``file_name`` and the last component of ``dir_name``; a cached
        matrix is reused even if ``dummy``/``use_raw``/``force_symmetry`` differ from when it was built.
    """
    global sparse_data_dir
    if chr_name is None:
        chr_name = get_chromosome_from_filename(file_name)
    # Last path component of dir_name (works with or without a trailing '/'); the cache sub-directory.
    sparse_rep_dir = dir_name[dir_name[: -1].rfind('/') + 1:]
    if sparse_dir is not None:
        sparse_data_dir = sparse_dir
    cache_dir = os.path.join(sparse_data_dir, sparse_rep_dir)
    os.makedirs(cache_dir, exist_ok=True)

    if file_name.endswith('.npz'):  # loading pre-combined and pre-compiled sparse data
        # os.path.join: callers pass os.path.dirname(...), which has no trailing separator.
        return scipy.sparse.load_npz(os.path.join(dir_name, file_name))

    cached_path = os.path.join(cache_dir, file_name + '.npz')
    if not ignore_sparse and os.path.isfile(cached_path):  # pre-compiled data already exists
        return scipy.sparse.load_npz(cached_path)

    # Otherwise generate the sparse matrix from the anchor to anchor file and cache it.
    if anchor_list is None:
        if anchor_dir is None:
            raise ValueError('You must supply either an anchor reference list or the directory containing one')
        anchor_list = pd.read_csv(os.path.join(anchor_dir, '%s.bed' % chr_name), sep='\t',
                                  names=['chr', 'start', 'end', 'anchor'])  # read anchor list file
    matrix_size = len(anchor_list)  # matrix size is needed to construct sparse CSR matrix
    anchor_dict = anchor_list_to_dict(anchor_list['anchor'].values)  # anchor --> index dictionary
    sparse_matrix = _read_anchor_to_anchor(os.path.join(dir_name, file_name), anchor_dict, matrix_size,
                                           dummy, use_raw)
    if force_symmetry:
        sparse_matrix = _force_symmetry(sparse_matrix)
    if not ignore_sparse:
        # save_npz appends the '.npz' suffix, matching cached_path above.
        scipy.sparse.save_npz(os.path.join(cache_dir, file_name), sparse_matrix)
    return sparse_matrix
# Depth labels with a matching DeepLoop_models/CPGZ_trained/<depth>.{json,h5} pair;
# the UI slider selects an index into this list.
model_depths = ['1.5M', '2M', '2.4M', '4.88M', '5M', '6.29M', '8.5M', '12.5M', '16.5M', '25M', '32M', '50M', '100M', '150M']

# Side length (in anchors) of the square tile passed to the model.
tile_size = 128

# Reference anchor BED for chr22; an alternative Arima anchor set lives at
# 'ref/hg19_Arima_anchor_bed/chr22.bed'.
anchor_file = 'ref/hg19_DPNII_anchor_bed/chr22.bed'

# Default model loaded at start-up (edit these two paths to use a different model).
# NOTE(review): predict() builds its own model for the selected depth, so no callback
# in this file reads this module-level `model`.
model_architecture = 'DeepLoop_models/CPGZ_trained/12.5M.json'
model_weights = 'DeepLoop_models/CPGZ_trained/12.5M.h5'
with open(model_architecture, 'r') as f:
    model = model_from_json(f.read())
model.load_weights(model_weights)

# Ratio matrix for the default input file; upload_file() replaces it at runtime.
input_matrix = load_chr_ratio_matrix_from_sparse(
    os.path.dirname(input_file),
    os.path.basename(input_file),
    os.path.dirname(anchor_file),
    force_symmetry=True,
)

# Anchor table (chr, start, end, anchor) for chr22.
anchor_list = pd.read_csv(anchor_file, sep='\t', names=['chr', 'start', 'end', 'anchor'])
# Models already built by _load_depth_model, keyed by depth label (e.g. '12.5M').
_model_cache = {}


def _load_depth_model(depth):
    """Return the CPGZ-trained DeepLoop model for ``depth``, building it from disk on first use only."""
    cached = _model_cache.get(depth)
    if cached is None:
        with open(f'DeepLoop_models/CPGZ_trained/{depth}.json', 'r') as f:
            cached = model_from_json(f.read())
        cached.load_weights(f'DeepLoop_models/CPGZ_trained/{depth}.h5')
        _model_cache[depth] = cached
    return cached


def predict(depth_idx):
    """
    Denoise a tile of the current input matrix and render the input next to the prediction.

    Args:
        depth_idx (:obj:`int`) : index into ``model_depths`` selecting the model (the UI slider value)
    Returns:
        ``numpy.array`` : uint8 RGB image of shape (height, width, 3) showing both heatmaps
    """
    selected_depth = model_depths[int(depth_idx)]  # int() in case the slider value arrives as a float
    model = _load_depth_model(selected_depth)  # cached: avoids rebuilding the model on every slider move

    # tile_size window centred on the middle of the matrix, shifted left when needed so it
    # does not run past the last row/column.
    n_anchors = input_matrix.shape[0]
    half = tile_size // 2
    i = max(0, min(n_anchors // 2 - half, n_anchors - tile_size))
    j = i + tile_size
    tile = input_matrix[i:j, i:j].toarray()
    if tile.shape != (tile_size, tile_size):
        # Matrix smaller than one tile: zero-pad so the model input and the reshape below stay valid.
        tile = np.pad(tile, ((0, tile_size - tile.shape[0]), (0, tile_size - tile.shape[1])))

    # Model input is a batch of one single-channel tile: (1, tile_size, tile_size, 1).
    denoised_tile = model.predict(tile[np.newaxis, :, :, np.newaxis]).reshape((tile_size, tile_size))
    denoised_tile[denoised_tile < 0] = 0
    denoised_tile = (denoised_tile + denoised_tile.T) / 2  # re-symmetrise the prediction

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    try:
        draw_heatmap(tile, 0, ax=ax1)
        draw_heatmap(denoised_tile, 0, ax=ax2)
        ax1.set_title('Input Tile')
        ax2.set_title(f'{selected_depth} model')
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); drop alpha and copy the
        # pixels so they outlive the figure.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])
    finally:
        plt.close(fig)  # always release the figure, even if drawing fails
def upload_file(file):
    """
    Replace the global input matrix with one loaded from an uploaded file.

    Depending on the Gradio version, the upload is delivered either as a path string or as a
    temp-file wrapper exposing the path as ``.name``; both are accepted.

    Args:
        file : uploaded file (path string/``os.PathLike`` or object with a ``.name`` path), or None
    Returns:
        None. ``input_file`` and ``input_matrix`` are only updated after a successful load.
    """
    global input_file, input_matrix
    print(file)
    if file is None:  # upload cleared / nothing selected
        return None
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name
    matrix = load_chr_ratio_matrix_from_sparse(os.path.dirname(path), os.path.basename(path),
                                               os.path.dirname(anchor_file), force_symmetry=True)
    # Update both globals together so they never describe different files.
    input_file = path
    input_matrix = matrix
    return None
# Gradio UI: an upload button for a new input file, a slider picking the model depth
# (an index into model_depths), and an image showing the input vs. denoised tile.
with gr.Blocks() as demo:
    with gr.Row():
        upload = gr.UploadButton("Upload a file", file_count="single")
    with gr.Row():
        slider = gr.Slider(
            minimum=0,
            maximum=len(model_depths) - 1,
            step=1,
            label='Model Depth',
            interactive=True,
        )
    heatmap = gr.Image(label='Visualization')

    # Uploading only reloads the matrix; the image is redrawn on the next slider change.
    upload.upload(fn=upload_file, inputs=upload)
    slider.change(fn=predict, inputs=[slider], outputs=heatmap)

if __name__ == "__main__":
    demo.queue().launch()