import argparse
import os
import shutil
import zipfile
from zipfile import ZipFile

import numpy as np
import pyBigWig
import torch
import gradio
from einops import rearrange

from pretrain.model import build_epd_model
from pretrain.track.model import build_track_model
from cage.model import build_cage_model
from cop.micro_model import build_microc_model
from cop.hic_model import build_hic_model
def parser_args():
"""
Hyperparameters for the pre-training model
"""
    parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--num_class', default=245, type=int,help='the number of epigenomic features to be predicted')
parser.add_argument('--seq_length', default=1600, type=int,help='the length of input sequences')
parser.add_argument('--nheads', default=4, type=int)
parser.add_argument('--hidden_dim', default=512, type=int)
parser.add_argument('--dim_feedforward', default=1024, type=int)
parser.add_argument('--enc_layers', default=1, type=int)
parser.add_argument('--dec_layers', default=2, type=int)
parser.add_argument('--dropout', default=0.2, type=float)
args, unknown = parser.parse_known_args()
return args,parser
def get_args():
    args, parser = parser_args()
    return args, parser
def parser_args_epi(parent_parser):
"""
    Hyperparameters for the downstream model to predict 1kb-resolution epigenomic feature signals
"""
parser=argparse.ArgumentParser(parents=[parent_parser],add_help = False)
parser.add_argument('--bins', type=int, default=500)
parser.add_argument('--crop', type=int, default=10)
parser.add_argument('--embed_dim', default=768, type=int)
parser.add_argument('--return_embed', default=False, action='store_true')
args, unknown = parser.parse_known_args()
return args
def parser_args_cage(parent_parser):
"""
Hyperparameters for the downstream model to predict 1kb-resolution CAGE-seq
"""
parser=argparse.ArgumentParser(parents=[parent_parser],add_help = False)
parser.add_argument('--bins', type=int, default=500)
parser.add_argument('--crop', type=int, default=10)
parser.add_argument('--embed_dim', default=768, type=int)
parser.add_argument('--return_embed', default=True, action='store_false')
args, unknown = parser.parse_known_args()
return args
def parser_args_hic(parent_parser):
"""
Hyperparameters for the downstream model to predict 5kb-resolution Hi-C and ChIA-PET
"""
parser=argparse.ArgumentParser(parents=[parent_parser],add_help = False)
parser.add_argument('--bins', type=int, default=200)
parser.add_argument('--crop', type=int, default=4)
parser.add_argument('--embed_dim', default=256, type=int)
args, unknown = parser.parse_known_args()
return args
def parser_args_microc(parent_parser):
"""
Hyperparameters for the downstream model to predict 1kb-resolution Micro-C
"""
parser=argparse.ArgumentParser(parents=[parent_parser],add_help = False)
parser.add_argument('--bins', type=int, default=500)
parser.add_argument('--crop', type=int, default=10)
parser.add_argument('--embed_dim', default=768, type=int)
parser.add_argument('--return_embed', default=True, action='store_false')
args, unknown = parser.parse_known_args()
return args
def check_region(chrom,start,end,ref_genome,region_len):
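    """
    Validate that the queried region has the required length and stays within the chromosome bounds
    """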
start,end=int(start),int(end)
if end-start != region_len:
if region_len==500000:
raise gradio.Error("Please enter a 500kb region!")
else:
raise gradio.Error("Please enter a 1Mb region!")
if start<300 or end > ref_genome.shape[1]-300:
raise gradio.Error("The start of input region should be greater than 300 and "
"the end of the region should be less than %s!"%(ref_genome.shape[1]-300))
return int(chrom),start,end
def generate_input(start,end,ref_genome,atac_seq):
    # pad 300bp of flanking reference sequence and ATAC-seq signal on each side of the region
pad_left=np.expand_dims(np.vstack((ref_genome[:,start-300:start],atac_seq[:,start-300:start])),0)
pad_right=np.expand_dims(np.vstack((ref_genome[:,end:end+300],atac_seq[:,end:end+300])),0)
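    # split the region into 1kb bins, keeping the stacked reference sequence and ATAC-seq channels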
center=np.vstack((ref_genome[:,start:end],atac_seq[:,start:end]))
center=rearrange(center,'n (b l)-> b n l',l=1000)
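    # flank each 1kb bin with the adjacent 300bp on both sides, giving 1600bp inputs for the model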
dmatrix = np.concatenate((pad_left, center[:, :, -300:]), axis=0)[:-1, :, :]
umatrix = np.concatenate((center[:, :, :300], pad_right), axis=0)[1:, :, :]
return np.concatenate((dmatrix, center, umatrix), axis=2)
def search_tf(tf):
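    """
    Return the index of an epigenomic feature in data/epigenomes.txt
    """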
with open('data/epigenomes.txt', 'r') as f:
epigenomes = f.read().splitlines()
tf_idx= epigenomes.index(tf)
return tf_idx
def predict_epb(
model_path,
region, ref_genome,atac_seq,
device,
cop_type
):
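    """
    Predict per-bin probabilities of the 245 epigenomic features with the pre-trained model
    """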
args, parser = get_args()
pretrain_model = build_epd_model(args)
pretrain_model.load_state_dict(torch.load(model_path,map_location=torch.device(device)))
pretrain_model.eval()
pretrain_model.to(device)
start,end=region
inputs=generate_input(start,end,ref_genome,atac_seq)
inputs=torch.tensor(inputs).float().to(device)
with torch.no_grad():
pred_epi=torch.sigmoid(pretrain_model(inputs)).detach().cpu().numpy()
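    # drop the flanking bins that lie outside the cropped output window (10kb for Micro-C, 20kb for Hi-C)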
if cop_type == 'Micro-C':
return pred_epi[10:-10,:]
else:
return pred_epi[20:-20,:]
def predict_epis(
model_path,
region, ref_genome,atac_seq,
device,
cop_type
):
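    """
    Predict 1kb-resolution epigenomic feature signal tracks with the downstream track model
    """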
args, parser = get_args()
epi_args = parser_args_epi(parser)
pretrain_model = build_track_model(epi_args)
pretrain_model.load_state_dict(torch.load(model_path,map_location=torch.device(device)))
pretrain_model.eval()
pretrain_model.to(device)
inputs=[]
start,end=region
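    # a 500kb Micro-C region fits in a single window; a 1Mb Hi-C region is tiled into two 500kb windows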
if cop_type == 'Micro-C':
inputs.append(generate_input(start,end,ref_genome,atac_seq))
else:
for loc in range(start+20000,end-20000,480000):
inputs.append(generate_input(loc-10000,loc+490000,ref_genome,atac_seq))
inputs=np.stack(inputs)
inputs=torch.tensor(inputs).float().to(device)
pred_epi=[]
with torch.no_grad():
for i in range(inputs.shape[0]):
pred_epi.append(pretrain_model(inputs[i:i+1]).detach().cpu().numpy())
out_epi = rearrange(np.vstack(pred_epi), 'i j k -> (i j) k')
return out_epi
def predict_cage(
model_path,
region, ref_genome, atac_seq,
device,
cop_type
):
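    """
    Predict 1kb-resolution CAGE-seq signals with the downstream CAGE model
    """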
args, parser = get_args()
cage_args = parser_args_cage(parser)
cage_model=build_cage_model(cage_args)
cage_model.load_state_dict(torch.load(model_path,map_location=torch.device(device)))
cage_model.eval()
cage_model.to(device)
inputs = []
start, end = region
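    # same windowing as predict_epis: one 500kb window for Micro-C, two 500kb windows for a 1Mb Hi-C region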
if cop_type == 'Micro-C':
inputs.append(generate_input(start, end, ref_genome, atac_seq))
else:
for loc in range(start + 20000, end - 20000, 480000):
inputs.append(generate_input(loc - 10000, loc + 490000, ref_genome, atac_seq))
inputs = np.stack(inputs)
inputs = torch.tensor(inputs).float().to(device)
pred_cage = []
with torch.no_grad():
for i in range(inputs.shape[0]):
pred_cage.append(cage_model(inputs[i:i + 1]).detach().cpu().numpy().squeeze())
return np.concatenate(pred_cage)
def arraytouptri(arrays,args):
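    """
    Fill the upper triangle of a square contact matrix from a flattened vector of predictions
    """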
effective_lens=args.bins-2*args.crop
triu_tup = np.triu_indices(effective_lens)
temp=np.zeros((effective_lens,effective_lens))
temp[triu_tup]=arrays
return temp
def complete_mat(mat):
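    """
    Mirror an upper-triangular contact matrix into a full symmetric matrix
    """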
temp = mat.copy()
np.fill_diagonal(temp,0)
mat= mat+temp.T
return mat
def predict_hic(
model_path,
region, ref_genome,atac_seq,
device
):
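    """
    Predict 5kb-resolution Hi-C and ChIA-PET contact maps for a 1Mb region
    """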
args, parser = get_args()
hic_args = parser_args_hic(parser)
hic_model = build_hic_model(hic_args)
hic_model.load_state_dict(torch.load(model_path,map_location=torch.device(device)))
hic_model.eval()
hic_model.to(device)
start,end=region
inputs=np.stack([generate_input(start,end,ref_genome,atac_seq)])
inputs=torch.tensor(inputs).float().to(device)
with torch.no_grad():
temp=hic_model(inputs).detach().cpu().numpy().squeeze()
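    # one contact map per output channel (CTCF ChIA-PET, POLR2 ChIA-PET, and Hi-C)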
return np.stack([complete_mat(arraytouptri(temp[:,i], hic_args)) for i in range(temp.shape[-1])])
def predict_microc(
model_path,
region, ref_genome,atac_seq,
device
):
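    """
    Predict the 1kb-resolution Micro-C contact map for a 500kb region
    """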
args, parser = get_args()
microc_args = parser_args_microc(parser)
microc_model = build_microc_model(microc_args)
microc_model.load_state_dict(torch.load(model_path,map_location=torch.device(device)))
microc_model.eval()
microc_model.to(device)
start,end=region
inputs=np.stack([generate_input(start,end,ref_genome,atac_seq)])
inputs=torch.tensor(inputs).float().to(device)
with torch.no_grad():
temp=microc_model(inputs).detach().cpu().numpy().squeeze()
return complete_mat(arraytouptri(temp, microc_args))
def filetobrowser(out_epis,out_cages,out_cop,chrom,start,end,file_id):
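    """
    Write the predicted tracks and contact maps to bigWig/bedpe files and bundle them into a zip archive
    """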
with open('data/epigenomes.txt', 'r') as f:
epigenomes = f.read().splitlines()
files_to_zip = file_id
if os.path.exists(files_to_zip):
shutil.rmtree(files_to_zip)
os.mkdir(files_to_zip)
hdr=[]
with open('data/chrom_size_hg38.txt', 'r') as f:
for line in f:
tmp=line.strip().split('\t')
hdr.append((tmp[0],int(tmp[1])))
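    # one bigWig track per predicted epigenomic feature, binned at 1kb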
for i in range(out_epis.shape[1]):
bwfile = pyBigWig.open(os.path.join(files_to_zip,"%s.bigWig"%epigenomes[i]), 'w')
bwfile.addHeader(hdr)
bwfile.addEntries(['chr' + str(chrom)]*out_epis.shape[0],[loc for loc in range(start,end,1000)],
ends=[loc+1000 for loc in range(start,end,1000)],values=out_epis[:,i].tolist())
bwfile.close()
bwfile = pyBigWig.open(os.path.join(files_to_zip,"cage.bigWig"),'w')
bwfile.addHeader(hdr)
bwfile.addEntries(['chr' + str(chrom)] * out_cages.shape[0], [loc for loc in range(start, end, 1000)],
ends=[loc + 1000 for loc in range(start, end, 1000)], values=out_cages.tolist())
bwfile.close()
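    # write the predicted contact maps as pairwise contact (.bedpe) files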
cop_lines=[]
interval=1000 if out_cop.shape[-1]==480 else 5000
if out_cop.shape[-1]==480:
for bin1 in range(out_cop.shape[-1]):
for bin2 in range(bin1,out_cop.shape[-1],1):
tmp = ['0', 'chr' + str(chrom), str(start + bin1 * interval), '0', '0', 'chr' + str(chrom),
str(start + bin2 * interval), '1', str(np.around(out_cop[bin1, bin2], 2))]
cop_lines.append('\t'.join(tmp)+'\n')
with open(os.path.join(files_to_zip,"microc.bedpe"),'w') as f:
f.writelines(cop_lines)
else:
types=['CTCF_ChIA-PET','POLR2_ChIA-PET','Hi-C']
        for i in range(len(types)):
            # reset so each contact type is written to its own bedpe file
            cop_lines = []
            for bin1 in range(out_cop.shape[-1]):
                for bin2 in range(bin1, out_cop.shape[-1], 1):
                    tmp=['0','chr' + str(chrom), str(start + bin1 * interval),'0','0','chr' +str(chrom),str(start + bin2 * interval),'1',str(np.around(out_cop[i,bin1, bin2], 2))]
                    cop_lines.append('\t'.join(tmp) + '\n')
            with open(os.path.join(files_to_zip,"%s.bedpe"%types[i]), 'w') as f:
                f.writelines(cop_lines)
out_zipfile = ZipFile("results/formatted_%s.zip" % file_id, "w", zipfile.ZIP_DEFLATED)
for file_to_zip in os.listdir(files_to_zip):
file_to_zip_full_path = os.path.join(files_to_zip, file_to_zip)
out_zipfile.write(filename=file_to_zip_full_path, arcname=file_to_zip)
out_zipfile.close()
shutil.rmtree(files_to_zip)
return "results/formatted_%s.zip"%file_id