Spaces:
Sleeping
Sleeping
import requests | |
import tensorflow as tf | |
import pandas as pd | |
import numpy as np | |
from operator import add | |
from functools import reduce | |
import random | |
import tabulate | |
from keras import Model | |
from keras import regularizers | |
from keras.optimizers import Adam | |
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax | |
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout | |
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape | |
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer | |
from keras.models import load_model | |
from keras.callbacks import EarlyStopping, ReduceLROnPlateau | |
from Bio import SeqIO | |
from Bio.SeqRecord import SeqRecord | |
from Bio.SeqFeature import SeqFeature, FeatureLocation | |
from Bio.Seq import Seq | |
import cyvcf2 | |
import parasail | |
import re | |
# One-hot code for each DNA base; column order is A, C, G, T.
ntmap = {
    base: tuple(1 if column == row else 0 for column in range(4))
    for row, base in enumerate("ACGT")
}
def get_seqcode(seq):
    """One-hot encode a nucleotide sequence for the model.

    Each base is looked up (case-insensitively) in the module-level ``ntmap``
    table, so any character other than A/C/G/T raises KeyError.

    Returns:
        Integer array of shape (1, len(seq), 4). An empty sequence gives an
        array of shape (1, 0, 4).
    """
    # Build the rows directly instead of concatenating tuples with reduce(add),
    # which copies the growing tuple on every step (quadratic in len(seq)).
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(rows), 4))
class PositionalEncoding(Layer):
    """Add a fixed sinusoidal position table to a (…, sequence_len, embedding_dim) input.

    The angle for position ``pos`` and dimension ``i`` is
    ``pos / 10000 ** (2 * i / embedding_dim)``; sine is applied to even
    dimensions and cosine to odd ones. Note this uses ``2*i`` for every
    dimension rather than ``2*(i//2)`` as in the original Transformer paper.
    It is kept unchanged because the network in this file is restored with
    ``load_weights`` and was presumably trained with exactly this table.
    """

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward name/dtype/trainable etc. so a layer rebuilt from get_config()
        # (e.g. by load_model) keeps them instead of silently dropping them.
        super().__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim
        # The table depends only on the two sizes, so compute it once here rather
        # than rebuilding it with nested Python loops on every call().
        self._encoding = None
        if sequence_len is not None and embedding_dim is not None:
            self._encoding = self._sinusoid_table(sequence_len, embedding_dim)

    @staticmethod
    def _sinusoid_table(sequence_len, embedding_dim):
        """Return the (sequence_len, embedding_dim) float64 position table."""
        table = np.array([
            [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]
            for pos in range(sequence_len)])
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        return table

    def call(self, x):
        table = self._encoding
        if table is None:
            # Sizes were not known at construction; fall back to the old behaviour.
            table = self._sinusoid_table(self.sequence_len, self.embedding_dim)
        position_embedding = tf.cast(table, dtype=tf.float32)
        return position_embedding + x

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config
def MultiHeadAttention_model(input_shape):
    """Build the on-target efficiency regressor.

    Architecture:
    - two Conv1D(256, 3, 'valid') + AveragePooling1D(2) + Dropout stages;
    - a BiLSTM(128) returning sequences;
    - sinusoidal positional encoding;
    - 2-head self-attention;
    - a dense head with a single linear output.

    Layers are created in the same order and with the same default names as
    before, so weights saved from earlier builds still load.

    Args:
        input_shape: (length, 4) shape of one one-hot encoded target; the
            length must be a concrete int.
    """
    lstm_units = 128
    inputs = Input(shape=input_shape)
    conv1 = Conv1D(256, 3, activation="relu")(inputs)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)
    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)
    lstm = Bidirectional(LSTM(lstm_units,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)
    # Sequence length after each 'valid' kernel-3 conv (length - 2) followed by
    # pooling by 2 (floor); equals the previously hard-coded 4 for 23-nt inputs.
    encoded_len = ((input_shape[0] - 2) // 2 - 2) // 2
    pos_embedding = PositionalEncoding(sequence_len=encoded_len,
                                       embedding_dim=2 * lstm_units)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)
    flat = Flatten()(atten)
    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)
    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)
    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)
    output = Dense(1, activation="linear")(drop5)
    model = Model(inputs=[inputs], outputs=[output])
    return model
def fetch_ensembl_transcripts(gene_symbol, timeout=30):
    """Look up a human gene symbol on the Ensembl REST API and return its transcripts.

    The request uses ``expand=1`` so each transcript carries its ``Exon`` list.

    Args:
        gene_symbol: HGNC-style gene symbol, e.g. "BRCA1".
        timeout: seconds to wait for Ensembl before giving up.

    Returns:
        The ``Transcript`` list from the lookup response. Returns None (after
        printing a message) on a network error, a non-200 response, or a
        response without transcripts.
    """
    from urllib.parse import quote

    url = (f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{quote(str(gene_symbol), safe='')}"
           "?expand=1;content-type=application/json")
    try:
        # Without a timeout a stalled connection would block forever.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error fetching gene data from Ensembl: {err}")
        return None
    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None
    gene_data = response.json()
    if 'Transcript' in gene_data:
        return gene_data['Transcript']
    print("No transcripts found for gene:", gene_symbol)
    return None
def fetch_ensembl_sequence(transcript_id, timeout=30):
    """Fetch the sequence of an Ensembl stable ID from the REST ``sequence/id`` endpoint.

    Callers in this file pass exon IDs, despite the parameter name.

    Args:
        transcript_id: Ensembl stable identifier (transcript, exon, ...).
        timeout: seconds to wait for Ensembl before giving up.

    Returns:
        The ``seq`` string from the response. Returns None (after printing a
        message) on a network error, a non-200 response, or a response
        without a sequence.
    """
    from urllib.parse import quote

    url = f"https://rest.ensembl.org/sequence/id/{quote(str(transcript_id), safe='')}?content-type=application/json"
    try:
        # Without a timeout a stalled connection would block forever.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error fetching sequence data from Ensembl: {err}")
        return None
    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None
    sequence_data = response.json()
    if 'seq' in sequence_data:
        return sequence_data['seq']
    print("No sequence found for transcript:", transcript_id)
    return None
def apply_mutation(ref_sequence, offset, ref, alt):
    """
    Apply a single VCF-style variant to the sequence.

    The ``len(ref)`` bases starting at 0-based ``offset`` are replaced by ``alt``.
    This one rule covers SNPs/MNPs (equal lengths), insertions (alt longer,
    sharing the VCF padding base) and deletions (alt shorter). The special
    allele ``"*"`` removes the reference bases entirely.

    Previously insertions replaced only one base (``offset + 1``), which left
    trailing reference bases duplicated whenever REF was longer than one base.
    """
    if alt == "*":
        return ref_sequence[:offset] + ref_sequence[offset + len(ref):]
    return ref_sequence[:offset] + alt + ref_sequence[offset + len(ref):]
def construct_combinations(sequence, mutations):
    """
    Construct all combinations of mutations.

    mutations is a list of tuples (offset, ref, alts): a 0-based offset into
    *sequence*, the reference allele there, and the alleles to try at that
    site. Each returned sequence applies exactly one entry of ``alts`` at every
    site, so a site with an empty ``alts`` list yields no sequences at all. The
    unmodified reference is included only if a site lists its ref among alts.

    All offsets refer to the ORIGINAL *sequence*. Sites are applied from the
    rightmost offset to the leftmost, so an indel never shifts the coordinates
    of a site still waiting to be applied. Overlapping sites are not reconciled.
    """
    # Stable sort: sites at equal offsets keep their input order.
    ordered = sorted(mutations, key=lambda site: site[0], reverse=True)
    sequences = [sequence]
    for offset, ref, alts in ordered:
        sequences = [apply_mutation(seq, offset, ref, alt)
                     for seq in sequences
                     for alt in alts]
    return sequences
def needleman_wunsch_alignment(query_seq, ref_seq, *, gap_open=10, gap_extend=1, matrix=None):
    """
    Estimate where a mutated target sequence starts within ref_seq.

    Runs a global (Needleman-Wunsch) alignment with parasail. It returns the
    0-based index in *ref_seq* of the reference base at which the aligned
    query begins, i.e. the number of reference bases aligned before the
    query's first base.

    Two changes from the previous version:
    - The previous version returned the start of the longest exact-match run,
      which is offset from the true target start whenever a variant falls
      early in the target.
    - It scored DNA with the protein matrix BLOSUM62. A nucleotide matrix
      (``parasail.dnafull``) is now the default.

    Args:
        query_seq: target sequence carrying variants.
        ref_seq: reference sequence to locate it in.
        gap_open, gap_extend: parasail gap penalties (positive numbers).
        matrix: parasail substitution matrix; None selects parasail.dnafull.
    """
    if matrix is None:
        matrix = parasail.dnafull
    alignment = parasail.nw_trace(query_seq, ref_seq, gap_open, gap_extend, matrix)
    # The gapped alignment strings make this independent of the CIGAR I/D convention.
    trace = alignment.traceback
    ref_pos = 0
    for query_char, _ in zip(trace.query, trace.ref):
        if query_char != "-":
            break
        # A column with a gap in the query must hold a reference base.
        ref_pos += 1
    return ref_pos
# Bases that can be encoded downstream (dnatorna / ntmap); anything else
# (N, symbolic alleles such as "<DEL>", breakends) is skipped.
_ACGT = frozenset("ACGT")
# Translation table that complements A/C/G/T.
_COMPLEMENT = str.maketrans("ACGT", "TGCA")


def _reverse_complement(seq):
    """Return the reverse complement of an A/C/G/T string."""
    return seq.translate(_COMPLEMENT)[::-1]


def _is_plain_allele(allele):
    """Return True when *allele* is a non-empty string of only A/C/G/T."""
    return bool(allele) and set(allele) <= _ACGT


def _collect_exon_variants(vcf_reader, exon_chr, start, end, strand):
    """Return (offset, ref, alts) sites for VCF records lying fully inside [start, end].

    Offsets and alleles are expressed on the same strand as the exon sequence.
    For strand == -1 they are reverse-complemented and measured from ``end``.
    Sites are sorted by descending offset so they can be applied right-to-left.
    """
    sites = []
    for record in vcf_reader(f"{exon_chr}:{start}-{end}"):
        ref = record.REF.upper()
        last = record.POS + len(ref) - 1
        # Records straddling the exon boundary would give a negative offset
        # (wrapping slices) or run past the end of the sequence.
        if record.POS < start or last > end:
            continue
        # Keep every ALT allele (the old ALT[:-1] dropped the last one, which
        # emptied biallelic sites and suppressed all targets for the exon).
        alts = [alt.upper() for alt in record.ALT
                if alt is not None and (alt == "*" or _is_plain_allele(alt.upper()))]
        if not alts:
            continue
        if strand == -1:
            offset = end - last
            ref = _reverse_complement(ref)
            alts = [alt if alt == "*" else _reverse_complement(alt) for alt in alts]
        else:
            offset = record.POS - start
        sites.append((offset, ref, alts))
    sites.sort(key=lambda site: site[0], reverse=True)
    return sites


def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
    """Find PAM-adjacent gRNA targets in an exon, with and without its VCF variants.

    *ref_sequence* is assumed to be on the exon's own strand: for strand == -1,
    index 0 corresponds to genomic ``end``, which the coordinate arithmetic
    below relies on. Variants from *vcf_reader* (queried by region string
    "chr:start-end") are mapped onto that strand before being applied.

    The reference itself is always scanned, plus every combination of ALT
    alleles from ``construct_combinations``. Only the first PAM base is
    treated as a wildcard (e.g. the N of "NGG"); the remaining PAM bases must
    match exactly.

    Returns:
        A list of unique targets in first-seen order, each
        [target_seq, gRNA, exon_chr, str(strand), tar_start, transcript_id,
        exon_id, gene_symbol, is_mut]. ``tar_start`` is a string. For targets
        absent from the reference (is_mut=True) it is an alignment-derived
        position suffixed with '*'.
    """
    sites = _collect_exon_variants(vcf_reader, exon_chr, start, end, strand)
    sequences = [ref_sequence]
    if sites:
        sequences.extend(construct_combinations(ref_sequence, sites))

    pam_len = len(pam)
    pam_tail = pam[1:]
    targets = []
    for seq in sequences:
        # i is the index of the PAM's first base; the protospacer precedes it.
        for i in range(target_length, len(seq) - pam_len + 1):
            if seq[i + 1:i + pam_len] != pam_tail:
                continue
            target_seq = seq[i - target_length:i + pam_len]
            if not set(target_seq) <= _ACGT:
                continue  # cannot be converted to RNA or one-hot encoded
            pos = ref_sequence.find(target_seq)
            is_mut = pos == -1
            if is_mut:
                pos = needleman_wunsch_alignment(target_seq, ref_sequence)
            if strand == -1:
                genomic_start = end - pos - (len(target_seq) - 1)
            else:
                genomic_start = start + pos
            tar_start = f"{genomic_start}*" if is_mut else str(genomic_start)
            gRNA = target_seq[:target_length].replace("T", "U")
            targets.append([target_seq, gRNA, exon_chr, str(strand), tar_start,
                            transcript_id, exon_id, gene_symbol, is_mut])
    # dict preserves insertion order, unlike the previous set-based dedup.
    unique = dict.fromkeys(tuple(target) for target in targets)
    return [list(target) for target in unique]
# Scoring models keyed by weights path. process_gene calls the formatter once
# per exon, and rebuilding the network and reloading weights each time dominated
# the run time. Weights are not re-read if the file changes during the process.
_PREDICTION_MODELS = {}


def _get_prediction_model(model_path):
    """Return the (23, 4)-input scoring model for *model_path*, building it on first use."""
    model = _PREDICTION_MODELS.get(model_path)
    if model is None:
        model = MultiHeadAttention_model(input_shape=(23, 4))
        model.load_weights(model_path)
        _PREDICTION_MODELS[model_path] = model
    return model


def format_prediction_output_with_mutation(targets, model_path):
    """Score targets with the on-target model and reshape them into output rows.

    Args:
        targets: rows from find_gRNA_with_mutation:
            [target_seq, gRNA, chr, strand, tar_start, transcript_id,
            exon_id, gene_symbol, is_mut].
        model_path: weights file for MultiHeadAttention_model.

    Returns:
        One row per target:
        [gene_symbol, chr, strand, tar_start, transcript_id, exon_id,
        target_seq, gRNA, prediction, is_mut]. Predictions above 100 are
        capped at 100.
    """
    if not targets:
        return []
    model = _get_prediction_model(model_path)
    # A single batched predict instead of one predict call per target.
    batch = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    predictions = model.predict(batch, verbose=0)[:, 0]

    formatted_data = []
    for target, raw_prediction in zip(targets, predictions):
        (target_seq, gRNA, exon_chr, strand, tar_start,
         transcript_id, exon_id, gene_symbol, is_mut) = target[:9]
        prediction = float(raw_prediction)
        if prediction > 100:
            prediction = 100
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
                               exon_id, target_seq, gRNA, prediction, is_mut])
    return formatted_data
def process_gene(gene_symbol, vcf_reader, model_path):
    """Find and score gRNA targets in every exon of every transcript of a gene.

    Returns:
        A tuple (results, all_gene_sequences, all_exons):
        - results: prediction rows from format_prediction_output_with_mutation,
          in discovery order (not sorted);
        - all_gene_sequences: each successfully fetched exon sequence, once
          per transcript occurrence;
        - all_exons: every Ensembl exon record, once per transcript occurrence.
    """
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    all_exons = []
    all_gene_sequences = []
    if not transcripts:
        print("Failed to retrieve transcripts.")
        return results, all_gene_sequences, all_exons

    # Exons are shared between transcripts; fetch each exon sequence only once.
    sequence_cache = {}
    for transcript in transcripts:
        exons = transcript.get('Exon', [])
        all_exons.extend(exons)
        transcript_id = transcript['id']
        for exon in exons:
            exon_id = exon['id']
            if exon_id not in sequence_cache:
                sequence_cache[exon_id] = fetch_ensembl_sequence(exon_id)
            gene_sequence = sequence_cache[exon_id]
            if not gene_sequence:
                print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                continue
            all_gene_sequences.append(gene_sequence)
            targets = find_gRNA_with_mutation(gene_sequence, exon['seq_region_name'],
                                              exon['start'], exon['end'], exon['strand'],
                                              transcript_id, exon_id, gene_symbol, vcf_reader)
            if targets:
                results.extend(format_prediction_output_with_mutation(targets, model_path))
    return results, all_gene_sequences, all_exons
def create_genbank_features(data):
    """Build one GenBank ``misc_feature`` per predicted gRNA target.

    *data* is either the list of rows produced by
    format_prediction_output_with_mutation or a DataFrame assumed to hold the
    same columns in the same order:
    [gene, chr, strand, target start, transcript, exon, target, gRNA,
    prediction, is_mut].

    Target starts are 1-based (Ensembl exon start plus a 0-based offset). Those
    of mutated targets carry a trailing '*' marking an approximate position.
    The '*' is stripped and the value shifted to the 0-based start that
    FeatureLocation expects. Rows whose start still cannot be parsed are
    reported and skipped.

    Raises:
        TypeError: if *data* is neither a list nor a pandas DataFrame.
    """
    if isinstance(data, pd.DataFrame):
        rows = data.values.tolist()
    elif isinstance(data, list):
        rows = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    features = []
    for row in rows:
        # Column 3 is the target start (column 1 is the chromosome name).
        raw_start = str(row[3]).rstrip('*')
        try:
            start = int(raw_start) - 1
        except ValueError as e:
            print(f"Error converting target start to int: {row[3]} - {e}")
            continue
        end = start + len(row[6])  # half-open end from the target sequence length
        # Column 2 is the strand, stored as '1' / '-1'.
        strand = 1 if str(row[2]) == '1' else -1
        location = FeatureLocation(start=start, end=end, strand=strand)
        is_mutation = 'Yes' if row[9] else 'No'
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # gRNA
            'note': f"Prediction: {row[8]}, Mutation: {is_mutation}",
        })
        features.append(feature)
    return features
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    """Write a GenBank file whose features are the predicted gRNA targets.

    Args:
        df: prediction rows (list or DataFrame) accepted by create_genbank_features.
        gene_sequence: the record's sequence: a string, a Seq, or a list/tuple
            of sequence fragments (e.g. exon sequences), which are concatenated.
        gene_symbol: used as the record id, name and in the description.
        output_path: destination path or handle for SeqIO.write.
    """
    if isinstance(gene_sequence, (list, tuple)):
        # str() of a list would embed "['ACG', ...]" brackets and quotes in the sequence.
        gene_sequence = "".join(str(part) for part in gene_sequence)
    elif not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)
    features = create_genbank_features(df)
    record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")
def create_bed_file_from_df(df, output_path):
    """Write prediction rows as a BED file (one line per target).

    Columns written: chrom, chromStart, chromEnd, name (gRNA), score
    (prediction), strand, plus an extra "Yes"/"No" mutation column.

    "Target Start" values are 1-based and may carry a trailing '*' (approximate
    position of a mutated target). BED's chromStart is 0-based, so the '*' is
    stripped and 1 subtracted; chromEnd is chromStart + len(Target).
    """
    with open(output_path, 'w', encoding='utf-8') as bed_file:
        for row in df.to_dict("records"):
            start = int(str(row["Target Start"]).rstrip('*')) - 1
            end = start + len(row["Target"])
            strand = '+' if str(row["Strand"]) == '1' else '-'
            is_mutation = 'Yes' if row["Is Mutation"] else 'No'
            bed_file.write(f"{row['Chr']}\t{start}\t{end}\t{row['gRNA']}\t"
                           f"{row['Prediction']}\t{strand}\t{is_mutation}\n")
def create_csv_from_df(df, output_path):
    """Save *df* to *output_path* as comma-separated text, leaving out the row index."""
    df.to_csv(path_or_buf=output_path, index=False)