Spaces:
Sleeping
Sleeping
import requests | |
import tensorflow as tf | |
import pandas as pd | |
import numpy as np | |
from operator import add | |
from functools import reduce | |
from Bio import SeqIO | |
from Bio.SeqRecord import SeqRecord | |
from Bio.SeqFeature import SeqFeature, FeatureLocation | |
from Bio.Seq import Seq | |
from keras.models import load_model | |
import random | |
# GPU setup, executed once at import time:
#   * turn on memory growth for every physical GPU, so TensorFlow allocates
#     device memory on demand instead of reserving it all up front;
#   * when at least one GPU is present, make only the first one visible.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
# One-hot vector for each nucleotide; channel order is A, C, G, T.
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide string.

    The sequence is upper-cased and each base is looked up in ``ntmap``.

    Args:
        seq: non-empty nucleotide string made of A/C/G/T (any case).

    Returns:
        Integer numpy array of shape ``(1, len(seq), 4)``; row ``j`` is the
        one-hot vector of base ``j``.

    Raises:
        ValueError: if ``seq`` is empty.
        KeyError: if ``seq`` contains a base other than A/C/G/T.
    """
    if not seq:
        raise ValueError("cannot encode an empty sequence")
    # Build one row per base (linear time) instead of repeatedly concatenating
    # tuples with reduce(add, ...), which is quadratic in len(seq).
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows).reshape((1, len(seq), -1))
from keras.models import load_model | |
class DCModelOntar:
    """Wrapper around a saved Keras model used for on-target scoring.

    ``is_reg`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, ontar_model_dir, is_reg=False):
        # Load the saved model from ``ontar_model_dir`` with keras ``load_model``.
        self.model = load_model(ontar_model_dir)

    def ontar_predict(self, x, channel_first=True):
        """Run the wrapped model on a batch and return its output flattened.

        Args:
            x: rank-4 array. When ``channel_first`` is true, its axes are
                reordered with ``transpose([0, 2, 3, 1])`` (axis 1 moves to
                the end) before prediction; otherwise it is used unchanged.
            channel_first: whether ``x`` must be transposed first.

        Returns:
            The model's ``predict`` output flattened with ``ravel()``.
        """
        model_input = x.transpose([0, 2, 3, 1]) if channel_first else x
        return self.model.predict(model_input).ravel()
def fetch_ensembl_transcripts(gene_symbol, timeout=30):
    """Look up a human gene symbol via the Ensembl REST API and return its transcripts.

    Queries ``/lookup/symbol/homo_sapiens/{gene_symbol}`` with ``expand=1`` so
    the JSON response embeds the gene's transcripts.

    Args:
        gene_symbol: gene symbol to look up.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        The ``'Transcript'`` list from the JSON response, or ``None`` when the
        request fails or times out, the status is not 200, or the response has
        no ``'Transcript'`` key. Failures are reported with ``print``.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error fetching gene data from Ensembl: {err}")
        return None
    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None
    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
def fetch_ensembl_sequence(transcript_id, timeout=30):
    """Fetch the sequence of an Ensembl feature by stable ID.

    Queries ``/sequence/id/{transcript_id}`` on the Ensembl REST API.

    Args:
        transcript_id: Ensembl stable ID of the feature whose sequence is wanted.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        The ``'seq'`` string from the JSON response, or ``None`` when the
        request fails or times out, the status is not 200, or the response has
        no ``'seq'`` key. Failures are reported with ``print``.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error fetching sequence data from Ensembl: {err}")
        return None
    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None
    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
def fetch_ensembl_cds(transcript_id, timeout=30):
    """Fetch CDS features overlapping an Ensembl feature.

    Queries ``/overlap/id/{transcript_id}?feature=cds`` on the Ensembl REST API.

    Args:
        transcript_id: Ensembl stable ID to query overlaps for.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        The parsed JSON body on HTTP 200, otherwise ``[]`` (including when the
        request fails or times out). Failures are reported with ``print``.
    """
    url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=cds;content-type=application/json"
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error fetching CDS data from Ensembl: {err}")
        return []
    if response.status_code != 200:
        print(f"Error fetching CDS data from Ensembl: {response.text}")
        return []
    return response.json()
def find_crispr_targets(sequence, chr, start, strand, transcript_id, pam="NGG", target_length=20):
    """Scan a sequence for PAM sites and collect the protospacer preceding each one.

    The sequence is upper-cased. When ``strand == -1`` it is reverse-complemented
    before scanning. A PAM matches when every PAM letter (an IUPAC code, e.g.
    ``N`` = any base, ``R`` = A/G) admits the corresponding sequence base. Only
    sites with at least ``target_length`` bases before the PAM are kept.
    Targets containing bases other than A/C/G/T are skipped, because they
    cannot be one-hot encoded downstream.

    Args:
        sequence: nucleotide string to scan.
        chr: chromosome / region label copied into every result row.
        start: coordinate added to the in-sequence offsets.
        strand: strand value; ``-1`` triggers reverse-complementing.
        transcript_id: label copied into every result row.
        pam: PAM pattern in IUPAC letters (default ``"NGG"``).
        target_length: protospacer length (default 20).

    Returns:
        A list of rows ``[target_seq, sgRNA, chr, tar_start, tar_end, strand,
        transcript_id]``, where ``target_seq`` is protospacer + PAM, ``sgRNA`` is
        the protospacer, and coordinates/strand are strings.
        ``tar_start = start + offset``; ``tar_end = tar_start + len(target_seq)``.

    NOTE(review): for ``strand == -1`` the offsets are measured in the
    reverse-complemented string but still added to ``start`` -- confirm this
    matches the coordinate convention of the sequence source.
    """
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
    iupac = {
        'A': 'A', 'C': 'C', 'G': 'G', 'T': 'T',
        'R': 'AG', 'Y': 'CT', 'S': 'CG', 'W': 'AT', 'K': 'GT', 'M': 'AC',
        'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG', 'N': 'ACGT',
    }
    sequence = sequence.upper()
    pam = pam.upper()
    if strand == -1:
        # Unknown bases become 'N' instead of raising KeyError; such targets
        # are filtered out below.
        sequence = ''.join(complement.get(base, 'N') for base in reversed(sequence))

    pam_len = len(pam)
    targets = []
    # i is the index of the first PAM base; the protospacer is
    # sequence[i - target_length:i], so i starts at target_length.
    for i in range(target_length, len(sequence) - pam_len + 1):
        window = sequence[i:i + pam_len]
        if not all(base in iupac.get(code, code) for code, base in zip(pam, window)):
            continue
        target_seq = sequence[i - target_length:i + pam_len]
        if set(target_seq) - set('ACGT'):
            continue
        sgRNA = sequence[i - target_length:i]
        tar_start = start + i - target_length
        tar_end = start + i + pam_len
        targets.append([target_seq, sgRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id])
    return targets
# Function to predict on-target efficiency and format output
def format_prediction_output(targets, model_path, model=None):
    """Score CRISPR targets with the on-target model and build output rows.

    Args:
        targets: rows shaped like ``find_crispr_targets`` output:
            ``[target_seq, sgRNA, chr, start, end, strand, transcript_id]``.
            ``target_seq`` must be 23 nt long (the encoding reshapes to 23).
        model_path: path given to ``DCModelOntar`` when ``model`` is None.
        model: optional already-constructed ``DCModelOntar``; lets callers
            scoring many batches avoid reloading the model each call.

    Returns:
        One row per target:
        ``[chr, start, end, strand, transcript_id, target_seq, sgRNA, prediction]``.
        ``[]`` when ``targets`` is empty (the model is not loaded in that case).
    """
    if not targets:
        return []
    if model is None:
        model = DCModelOntar(model_path)
    # NOTE(review): get_seqcode yields shape (1, 23, 4) (position-major), and
    # reshape(-1, 4, 1, 23) reinterprets that memory rather than moving the
    # nucleotide axis (a transpose would). Kept as-is because the model may
    # have been trained on this exact layout -- confirm before changing.
    encoded = np.concatenate(
        [get_seqcode(target[0]).reshape(-1, 4, 1, 23) for target in targets],
        axis=0,
    )
    # One batched predict call instead of one call per target. Keep the first
    # model output of each sample, as the per-target version did with [0].
    predictions = np.asarray(model.ontar_predict(encoded)).reshape(len(targets), -1)[:, 0]
    formatted_data = []
    for target, prediction in zip(targets, predictions):
        target_seq = target[0]
        sgRNA = target[1]
        chr = target[2]
        start = target[3]
        end = target[4]
        strand = target[5]
        transcript_id = target[6]
        formatted_data.append([chr, start, end, strand, transcript_id, target_seq, sgRNA, prediction])
    return formatted_data
def process_gene(gene_symbol, model_path):
    """Find and score CRISPR targets in every exon of every transcript of a gene.

    For each transcript returned by ``fetch_ensembl_transcripts``, fetches its
    CDS features and then, per exon, the exon sequence. Each sequence is
    scanned with ``find_crispr_targets`` and the hits are scored with
    ``format_prediction_output``. Progress problems are reported with ``print``.

    Args:
        gene_symbol: gene symbol to look up in Ensembl.
        model_path: path to the saved on-target model.

    Returns:
        A 4-tuple ``(results, gene_sequence, Exons, cds_list)``:
        * results: one list of formatted rows per exon that yielded targets;
        * gene_sequence: the value fetched for the last exon processed
          (``None`` if that fetch failed or nothing was processed);
        * Exons: exon list of the last transcript (``[]`` if none);
        * cds_list: CDS features of the last transcript (``[]`` if none).
    """
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    # Defaults keep the return statement valid when nothing is processed.
    gene_sequence = None
    Exons = []
    cds_list = []
    if not transcripts:
        print("Failed to retrieve transcripts.")
        return results, gene_sequence, Exons, cds_list

    for transcript in transcripts:
        Exons = transcript.get('Exon', [])
        # Human-readable name, used only to label the targets.
        transcript_id = transcript['display_name']
        # The overlap endpoint needs the stable Ensembl ID, not the display name.
        cds_list = fetch_ensembl_cds(transcript['id'])
        for exon in Exons:
            gene_sequence = fetch_ensembl_sequence(exon['id'])
            if not gene_sequence:
                print("Failed to retrieve gene sequence.")
                continue
            targets = find_crispr_targets(
                gene_sequence,
                exon['seq_region_name'],
                exon['start'],
                exon['strand'],
                transcript_id,
            )
            if not targets:
                print("No gRNA sites found in the gene sequence.")
                continue
            # Predict on-target efficiency for each gRNA site
            results.append(format_prediction_output(targets, model_path))

    # Note: returns the last exon's sequence; may need adjusting per use-case.
    return results, gene_sequence, Exons, cds_list
def create_genbank_features(formatted_data):
    """Convert formatted prediction rows into Biopython ``misc_feature`` features.

    Rows must start with ``[chr, start, end, strand, ...]`` and end with
    ``[..., target_seq, sgRNA, prediction]``. This covers the 8-column rows built
    by ``format_prediction_output`` (which insert ``transcript_id`` at index 4)
    as well as 7-column rows without it; trailing indices are used so both work.

    Args:
        formatted_data: iterable of rows as described above.

    Returns:
        A list of ``SeqFeature`` objects, one per row.
    """
    features = []
    for data in formatted_data:
        # Strand may be '+'/'-' or Ensembl-style 1/-1, possibly stored as a
        # string ('1'/'-1'); map everything positive-looking to +1.
        strand = 1 if str(data[3]).strip() in ('+', '1', '+1') else -1
        location = FeatureLocation(start=int(data[1]), end=int(data[2]), strand=strand)
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': data[-2],  # sgRNA (protospacer) sequence
            'target': data[-3],  # full target sequence including the PAM
            'note': f"Prediction: {data[-1]}",  # on-target prediction score
        })
        features.append(feature)
    return features
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    """Write ``gene_sequence`` as a GenBank record with one feature per row of ``df``.

    Each row adds a ``gene`` feature spanning ``"Start Pos"``..``"End Pos"`` on
    ``"Strand"``, with ``locus_tag`` taken from ``"Gene ID"`` and a note holding
    the gRNA and prediction score.

    Args:
        df: DataFrame with columns "Start Pos", "End Pos", "Strand", "Gene ID",
            "gRNA" and "Prediction".
        gene_sequence: nucleotide string used as the record sequence.
        gene_symbol: used as the record id and name.
        output_path: destination file for ``SeqIO.write``.

    NOTE(review): coordinates are used exactly as stored in ``df``; if they are
    genomic positions while ``gene_sequence`` is a single exon, features will
    fall outside the record -- confirm with the caller.
    """
    features = []
    for _, row in df.iterrows():
        raw_strand = row["Strand"]
        # Accept '+'/'-' as well as Ensembl-style 1/-1; int('+') would raise.
        if isinstance(raw_strand, str) and raw_strand.strip() in ('+', '-'):
            strand = 1 if raw_strand.strip() == '+' else -1
        else:
            strand = int(raw_strand)
        location = FeatureLocation(start=int(row["Start Pos"]),
                                   end=int(row["End Pos"]),
                                   strand=strand)
        feature = SeqFeature(location=location, type="gene", qualifiers={
            'locus_tag': row["Gene ID"],  # Assuming Gene ID is equivalent to Chromosome here
            'note': f"gRNA: {row['gRNA']}, Prediction: {row['Prediction']}"
        })
        features.append(feature)
    record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
                       description='CRISPR Cas9 predicted targets', features=features)
    # GenBank output requires a molecule_type annotation.
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")
def create_bed_file_from_df(df, output_path): | |
with open(output_path, 'w') as bed_file: | |
for index, row in df.iterrows(): | |
chrom = row["Gene ID"] | |
start = int(row["Start Pos"]) | |
end = int(row["End Pos"]) | |
strand = '+' if row["Strand"] == '+' else '-' | |
gRNA = row["gRNA"] | |
score = str(row["Prediction"]) # Ensure score is converted to string if not already | |
bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n") | |
def create_csv_from_df(df, output_path):
    """Save ``df`` to ``output_path`` as CSV, leaving out the DataFrame index."""
    df.to_csv(path_or_buf=output_path, index=False)