Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from keras import regularizers | |
from keras.layers import Input, Dense, Dropout, Activation, Conv1D | |
from keras.layers import GlobalAveragePooling1D, AveragePooling1D | |
from keras.layers import Bidirectional, LSTM | |
from keras import Model | |
from keras.metrics import MeanSquaredError | |
import pandas as pd | |
import numpy as np | |
import requests | |
from functools import reduce | |
from operator import add | |
import tabulate | |
from difflib import SequenceMatcher | |
import cyvcf2 | |
import parasail | |
import re | |
# One-hot encoding of each nucleotide; channel order is A, C, G, T.
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide sequence as a single-item batch.

    Args:
        seq: DNA string made of A/C/G/T; lower-case letters are accepted
            because the sequence is upper-cased before lookup.

    Returns:
        Integer ``numpy`` array of shape ``(1, len(seq), 4)``. An empty
        sequence yields an empty array of shape ``(1, 0, 4)``.

    Raises:
        KeyError: if ``seq`` contains a character that is not a key of
            ``ntmap`` (for example ``'N'``).
    """
    # Build one row per base directly; chaining tuple concatenation with
    # reduce() is quadratic and fails on an empty sequence.
    rows = [ntmap[base] for base in seq.upper()]
    width = len(ntmap['A'])
    return np.array(rows, dtype=int).reshape((1, len(seq), width))
def BiLSTM_model(input_shape):
    """Build the uncompiled Conv1D + bidirectional-LSTM regression network.

    Architecture, in order:
      * two stages of Conv1D(128 filters, kernel 5, ReLU) ->
        AveragePooling1D(2) -> Dropout(0.1);
      * a bidirectional LSTM (128 units per direction, returning the full
        sequence) followed by global average pooling;
      * three ReLU dense layers (128, 32, 32 units), each followed by
        Dropout(0.1);
      * a single linear output unit.

    The LSTM kernel and the dense kernels/biases carry an L2 penalty of 1e-4.

    Args:
        input_shape: shape of one input sample, passed straight to
            ``keras.layers.Input``.

    Returns:
        A ``keras.Model`` mapping the input tensor to one scalar per sample.
        The model is not compiled here.
    """
    seq_in = Input(shape=input_shape)

    # Layers are instantiated in the same order as the straight-line
    # definition so that weight files saved from it still line up.
    x = seq_in
    for _ in range(2):
        x = Conv1D(128, 5, activation="relu")(x)
        x = AveragePooling1D(2)(x)
        x = Dropout(0.1)(x)

    recurrent = LSTM(
        128,
        dropout=0.1,
        activation='tanh',
        return_sequences=True,
        kernel_regularizer=regularizers.l2(1e-4),
    )
    x = Bidirectional(recurrent)(x)
    x = GlobalAveragePooling1D()(x)

    for units in (128, 32, 32):
        x = Dense(
            units,
            kernel_regularizer=regularizers.l2(1e-4),
            bias_regularizer=regularizers.l2(1e-4),
            activation="relu",
        )(x)
        x = Dropout(0.1)(x)

    score = Dense(1, activation="linear")(x)
    return Model(inputs=[seq_in], outputs=[score])
def fetch_ensembl_transcripts(gene_symbol, timeout=30):
    """Look up a human gene by symbol on the Ensembl REST API.

    Args:
        gene_symbol: gene symbol to look up (e.g. ``"BRCA1"``).
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The value of the ``'Transcript'`` key of the JSON response, or
        ``None`` when the request fails, the server does not answer with
        HTTP 200, or the response carries no ``'Transcript'`` key. A message
        is printed in every failure case.
    """
    # Local import so this function does not rely on a file-level import.
    from urllib.parse import quote

    # Escape the symbol so characters such as '/' or '?' cannot alter the
    # request path or query string.
    symbol = quote(str(gene_symbol), safe="")
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{symbol}?expand=1;content-type=application/json"

    try:
        # Without a timeout, requests can block indefinitely on a stalled
        # connection.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
def fetch_ensembl_sequence(transcript_id, timeout=30):
    """Fetch the sequence of an Ensembl feature by its stable id.

    Despite the parameter name, any id accepted by the ``sequence/id``
    endpoint can be passed; ``process_gene`` calls this with exon ids.

    Args:
        transcript_id: Ensembl stable id to fetch the sequence for.
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The value of the ``'seq'`` key of the JSON response, or ``None``
        when the request fails, the server does not answer with HTTP 200, or
        the response carries no ``'seq'`` key. A message is printed in every
        failure case.
    """
    # Local import so this function does not rely on a file-level import.
    from urllib.parse import quote

    # Escape the id so it cannot alter the request path or query string.
    feature_id = quote(str(transcript_id), safe="")
    url = f"https://rest.ensembl.org/sequence/id/{feature_id}?content-type=application/json"

    try:
        # Without a timeout, requests can block indefinitely on a stalled
        # connection.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
    """Scan a sequence for PAM-anchored target windows.

    Every window of ``target_length`` bases is examined. A window is a hit
    when the bases starting at offset 4 match ``pam``. The 20 bases that
    follow the PAM are reported as the guide RNA (with T written as U).

    Args:
        sequence: DNA sequence to scan, expected in upper case.
        chr: chromosome / sequence-region name copied into each record
            (parameter name kept for caller compatibility).
        start: genomic coordinate of the first base of ``sequence`` on the
            forward strand.
        end: genomic coordinate of the last base of ``sequence`` on the
            forward strand.
        strand: ``-1`` when ``sequence`` is given on the reverse strand; any
            other value is handled as the forward strand.
        transcript_id: identifier copied into each record.
        exon_id: identifier copied into each record.
        pam: PAM motif in IUPAC notation (case-insensitive). Trailing ``N``
            codes place no constraint on the window, so the default
            ``"TTTN"`` requires exactly ``TTT`` at offsets 4-6.
        target_length: length of each reported window.

    Returns:
        A list of records
        ``[target_seq, gRNA, chr, start, end, strand, transcript_id, exon_id]``
        where ``start``, ``end`` and ``strand`` are strings and the
        coordinates are inclusive genomic positions of the window.

    Raises:
        ValueError: if ``pam`` contains a character that is not an IUPAC
            nucleotide code.
        KeyError: if the guide region of a matching window contains a base
            other than A, C, G or T.
    """
    iupac = {
        'A': 'A', 'C': 'C', 'G': 'G', 'T': 'T',
        'R': 'AG', 'Y': 'CT', 'S': 'CG', 'W': 'AT', 'K': 'GT', 'M': 'AC',
        'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG', 'N': 'ACGT',
    }
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    pam_offset = 4      # bases kept upstream of the PAM inside each window
    guide_length = 20   # bases reported as the guide, right after the PAM

    pam = pam.upper()
    try:
        # Trailing wildcards constrain nothing, so only the leading part of
        # the motif has to be compared against the window.
        pam_core = [iupac[code] for code in pam.rstrip('N')]
    except KeyError as exc:
        raise ValueError(f"Unsupported PAM character: {exc.args[0]!r}") from None

    core_end = pam_offset + len(pam_core)
    guide_start = pam_offset + len(pam)
    guide_end = guide_start + guide_length

    targets = []
    for i in range(len(sequence) - target_length + 1):
        target_seq = sequence[i:i + target_length]
        site = target_seq[pam_offset:core_end]
        if len(site) != len(pam_core):
            continue
        if not all(base in allowed for base, allowed in zip(site, pam_core)):
            continue

        if strand == -1:
            # The sequence runs from `end` downwards on the reverse strand,
            # so offset i is counted back from `end`.
            tar_start = end - i - target_length + 1
            tar_end = end - i
        else:
            tar_start = start + i
            tar_end = start + i + target_length - 1

        gRNA = ''.join(dnatorna[base] for base in target_seq[guide_start:guide_end])
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
    return targets
def format_prediction_output(targets, model_path):
    """Score target windows with the BiLSTM model and build output rows.

    Args:
        targets: records as produced by ``find_crispr_targets``:
            ``[target_seq, gRNA, chr, start, end, strand, transcript_id,
            exon_id]``. ``target_seq`` is one-hot encoded with
            ``get_seqcode`` and fed to a model built for input shape
            ``(34, 4)``.
        model_path: path of the weights file passed to ``load_weights``.

    Returns:
        One row per target, in input order:
        ``[chr, start, end, strand, transcript_id, exon_id, target_seq,
        gRNA, prediction]``. Predictions above 100 are capped at 100.
    """
    if not targets:
        # Nothing to score: skip building the network and reading weights.
        return []

    Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
    Crispr_BiLSTM.load_weights(model_path)

    # Encode everything up front and run a single batched predict() call;
    # invoking predict() once per sequence pays its per-call overhead for
    # every target.
    batch = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = Crispr_BiLSTM.predict(batch, verbose=0)

    formatted_data = []
    for target, score in zip(targets, scores[:, 0]):
        prediction = float(score)
        if prediction > 100:
            prediction = 100
        target_seq, gRNA, chrom, start, end, strand, transcript_id, exon_id = target[:8]
        formatted_data.append([chrom, start, end, strand, transcript_id, exon_id, target_seq, gRNA, prediction])
    return formatted_data
def process_gene(gene_symbol, model_path):
    """Find and score candidate targets in every exon of a gene.

    Transcripts are looked up with ``fetch_ensembl_transcripts``. For each
    exon of each transcript the sequence is downloaded with
    ``fetch_ensembl_sequence`` and scanned with ``find_crispr_targets``. All
    targets are then scored by ``format_prediction_output``.

    Args:
        gene_symbol: gene symbol to look up.
        model_path: path of the model weights file.

    Returns:
        A tuple ``(sorted_results, all_gene_sequences, all_exons)``:
          * ``sorted_results``: prediction rows sorted by score (index 8),
            highest first;
          * ``all_gene_sequences``: the sequence of every exon that could be
            fetched, one entry per (transcript, exon) pair;
          * ``all_exons``: every exon record of every transcript.
        All three are empty when no transcripts could be retrieved.
    """
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    all_exons = []           # every exon record, across transcripts
    all_gene_sequences = []  # every successfully fetched exon sequence
    all_targets = []         # targets from all exons, scored in one go
    sequence_cache = {}      # exon id -> sequence, to avoid re-downloading

    if not transcripts:
        print("Failed to retrieve transcripts.")
        return [], all_gene_sequences, all_exons

    for transcript in transcripts:
        exons = transcript['Exon']
        all_exons.extend(exons)
        transcript_id = transcript['id']
        for exon in exons:
            exon_id = exon['id']
            # Transcripts of one gene share exons; fetch each one once.
            # Failed fetches are not cached, so they are retried next time.
            gene_sequence = sequence_cache.get(exon_id)
            if gene_sequence is None:
                gene_sequence = fetch_ensembl_sequence(exon_id)
                if gene_sequence:
                    sequence_cache[exon_id] = gene_sequence
            if not gene_sequence:
                print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                continue

            all_gene_sequences.append(gene_sequence)
            all_targets.extend(find_crispr_targets(
                gene_sequence,
                exon['seq_region_name'],
                exon['start'],
                exon['end'],
                exon['strand'],
                transcript_id,
                exon_id,
            ))

    # Score everything with a single call so the network is built and its
    # weights are loaded once rather than once per exon.
    results = format_prediction_output(all_targets, model_path) if all_targets else []

    # The prediction score sits at index 8 of each row.
    sorted_results = sorted(results, key=lambda row: row[8], reverse=True)
    return sorted_results, all_gene_sequences, all_exons