In [None]:
import sys

from google.colab import drive

# Mount Google Drive at /content/drive so files stored there are reachable from the runtime.
drive.mount('/content/drive')

# Add the project folder on Drive to the module search path.
sys.path.append('/content/drive/MyDrive/Colab Notebooks/Cas9/On target')

In [None]:
import requests
import tensorflow as tf
import pandas as pd
import numpy as np
from operator import add
from functools import reduce
import random
import tabulate

from keras import Model
from keras import regularizers
from keras.optimizers import Adam
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
from keras.models import load_model
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

!pip install cyvcf2
import cyvcf2
!pip install parasail
import parasail

import re

### Data Encoding

In [None]:
# One-hot code for each nucleotide; channel order is A, C, G, T.
ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }

def get_seqcode(seq):
    """One-hot encode a DNA sequence.

    Args:
        seq: String of nucleotides; case-insensitive, only A/C/G/T are supported.

    Returns:
        Integer ``np.ndarray`` of shape ``(1, len(seq), 4)`` (a batch holding one
        sequence). An empty string yields an array of shape ``(1, 0, 4)``.

    Raises:
        KeyError: If ``seq`` contains a character other than A, C, G or T.
    """
    # Build one row per base directly instead of repeatedly concatenating tuples.
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(rows), 4))

### Attention model

In [None]:
class PositionalEncoding(Layer):
    """Keras layer that adds a fixed sinusoidal position signal to its input.

    The layer has no weights: the position table is computed with NumPy inside
    ``call`` and added to the incoming tensor.

    Args:
        sequence_len: Number of positions (rows of the position table).
        embedding_dim: Number of features per position (columns of the table).
        **kwargs: Standard ``Layer`` keyword arguments (e.g. ``name``, ``dtype``,
            ``trainable``), forwarded to the base class.
    """

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward base-layer kwargs instead of discarding them, so that a layer
        # rebuilt from get_config() keeps its name/dtype/trainable settings.
        super().__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        """Return ``x`` plus the ``(sequence_len, embedding_dim)`` position table."""
        # NOTE(review): the exponent uses the raw column index i (not i // 2 as in
        # the usual Transformer formulation). Left unchanged on purpose: saved
        # weights were presumably trained with exactly this table -- confirm
        # before altering.
        position_embedding = np.array([
            [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
            for pos in range(self.sequence_len)])

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # even columns
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # odd columns
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)

        # Broadcasts over any leading axes of x.
        return position_embedding + x

    def get_config(self):
        """Return the base config extended with this layer's constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config

def MultiHeadAttention_model(input_shape):
    """Build the Conv1D -> BiLSTM -> multi-head self-attention regression model.

    Args:
        input_shape: ``(sequence_length, channels)`` of one encoded target; the
            callers in this notebook pass ``(23, 4)``.

    Returns:
        A Keras ``Model`` with a single linear output unit. The model is returned
        without being compiled.
    """
    # Layers are created in the same order as before so auto-generated layer
    # names, and therefore weight files saved for this architecture, still match.
    inputs = Input(shape=input_shape)

    conv1 = Conv1D(256, 3, activation="relu")(inputs)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)

    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)

    lstm = Bidirectional(LSTM(128,
                               dropout=0.5,
                               activation='tanh',
                               return_sequences=True,
                               kernel_regularizer=regularizers.l2(0.01)))(drop2)

    # Time steps remaining after two (kernel-3 Conv1D -> pool-2 AveragePooling1D)
    # stages. Derived from input_shape instead of a hard-coded 23; for a length
    # of 23 this is 4, the same value the previous constant expression produced.
    attention_len = ((input_shape[0] - 3 + 1) // 2 - 3 + 1) // 2
    # The bidirectional LSTM concatenates both directions, hence 2 * 128 features.
    pos_embedding = PositionalEncoding(sequence_len=attention_len, embedding_dim=2*128)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)

    flat = Flatten()(atten)

    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[inputs], outputs=[output])
    return model

### Predict gRNA in one specific gene

In [None]:
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene by symbol on the Ensembl REST API and return its transcripts.

    Args:
        gene_symbol: Gene symbol inserted into the lookup URL (e.g. ``'MCM2'``).

    Returns:
        The value of the ``'Transcript'`` field of the JSON response, or ``None``
        when the request fails, the status code is not 200, or the field is
        missing. A message is printed in each failure case.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the notebook forever.
        response = requests.get(url, timeout=30)
    except requests.exceptions.RequestException as exc:
        # Same convention as the non-200 path below: report and return None.
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']

def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence of an Ensembl feature by its stable ID.

    Args:
        transcript_id: Ensembl stable ID inserted into the request URL. Despite
            the parameter name, the callers in this notebook pass exon IDs.

    Returns:
        The ``'seq'`` field of the JSON response, or ``None`` when the request
        fails, the status code is not 200, or the field is missing. A message is
        printed in each failure case.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the notebook forever.
        response = requests.get(url, timeout=30)
    except requests.exceptions.RequestException as exc:
        # Same convention as the non-200 path below: report and return None.
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']


In [None]:
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
    """Scan a sequence for protospacer + PAM sites.

    A site is reported at index ``i`` when the bases following the first PAM
    position equal ``pam[1:]`` (the first PAM base is not checked, i.e. it is
    treated as a wildcard) and at least ``target_length`` bases precede it.

    Args:
        sequence: Sequence to scan. For ``strand == -1`` the coordinate arithmetic
            treats ``sequence[0]`` as genomic position ``end``; otherwise as ``start``.
        chr: Chromosome name, copied into every result row.
        start: Genomic coordinate of the first base of the region.
        end: Genomic coordinate of the last base of the region.
        strand: ``-1`` for the reverse strand; any other value is handled as forward.
        transcript_id: Copied into every result row.
        exon_id: Copied into every result row.
        pam: PAM pattern; only positions after the first are matched, literally.
        target_length: Protospacer length preceding the PAM.

    Returns:
        List of ``[target_seq, gRNA, chr, tar_start, tar_end, strand,
        transcript_id, exon_id]`` rows, where ``target_seq`` is protospacer + PAM,
        ``gRNA`` is the protospacer with T replaced by U, and the coordinates and
        strand are strings.

    Raises:
        KeyError: If a reported protospacer contains a base other than A/C/G/T.
    """
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    # Use the real PAM length everywhere; previously the match, the slice and
    # the coordinates assumed a 3-nt PAM regardless of the `pam` argument.
    pam_len = len(pam)
    pam_suffix = pam[1:]

    targets = []
    # Starting at target_length guarantees a full protospacer before the PAM.
    for i in range(target_length, len(sequence) - pam_len + 1):
        if sequence[i + 1:i + pam_len] != pam_suffix:
            continue

        protospacer = sequence[i - target_length:i]
        target_seq = sequence[i - target_length:i + pam_len]
        if strand == -1:
            # Index k of the sequence corresponds to genomic position end - k.
            tar_start = end - (i + pam_len - 1)
            tar_end = end - (i - target_length)
        else:
            tar_start = start + i - target_length
            tar_end = start + i + pam_len - 1

        gRNA = ''.join(dnatorna[base] for base in protospacer)
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])

    return targets


In [None]:
# Function to predict on-target efficiency and format output
def format_prediction_output(targets, model_path):
    """Score each target with the attention model and reorder the fields for output.

    Args:
        targets: Rows of ``[target_seq, gRNA, chr, start, end, strand,
            transcript_id, exon_id]`` as produced by ``find_crispr_targets``.
        model_path: Path of the weights file loaded into a freshly built
            ``MultiHeadAttention_model(input_shape=(23, 4))``.

    Returns:
        List of ``[chr, start, end, strand, transcript_id, exon_id, target_seq,
        gRNA, prediction]`` rows in the order of ``targets``. Predictions above
        100 are capped at 100. An empty ``targets`` returns ``[]`` without
        building the model.
    """
    if not targets:
        return []

    # NOTE: the model is rebuilt and its weights reloaded on every call.
    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    # Encode everything first and run ONE batched prediction; calling
    # model.predict() once per target is dramatically slower. Scores may differ
    # from per-sample prediction only by floating-point rounding.
    encoded_batch = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = np.asarray(model.predict(encoded_batch, verbose=0)).reshape(-1)

    formatted_data = []
    for target, score in zip(targets, scores):
        prediction = float(score)
        if prediction > 100:
            prediction = 100

        target_seq, gRNA, chrom, start, end, strand, transcript_id, exon_id = target[:8]
        formatted_data.append([chrom, start, end, strand, transcript_id, exon_id, target_seq, gRNA, prediction])

    return formatted_data

In [None]:
def gRNADesign(gene_symbol, model_path, write_to_csv=False):
    """Design and score gRNAs for every exon of every transcript of a gene.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        model_path: Weights path forwarded to ``format_prediction_output``.
        write_to_csv: When ``True`` the ranked rows are written to a CSV on
            Drive and nothing is returned.

    Returns:
        The rows sorted by predicted score (highest first) when ``write_to_csv``
        is not ``True``; otherwise ``None``.
    """
    scored_rows = []

    for transcript in fetch_ensembl_transcripts(gene_symbol) or []:
        exons = transcript['Exon']
        transcript_id = transcript['id']

        for exon in exons:
            exon_id = exon['id']
            exon_sequence = fetch_ensembl_sequence(exon_id)
            if not exon_sequence:
                continue

            start = exon['start']
            end = exon['end']
            strand = exon['strand']
            chrom = exon['seq_region_name']
            targets = find_crispr_targets(exon_sequence, chrom, start, end, strand, transcript_id, exon_id)
            if not targets:
                continue

            # Collect the per-exon rows straight into one flat list.
            scored_rows.extend(format_prediction_output(targets, model_path))

    header = ['Chrom','Start','End','Strand','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score']
    # Column 8 holds the predicted score.
    sort_output = sorted(scored_rows, key=lambda row: row[8], reverse=True)

    if write_to_csv == True:
        frame = pd.DataFrame(data=sort_output, columns=header)
        frame.to_csv(f'/content/drive/MyDrive/Colab Notebooks/Cas9/On target/design_results/Cas9_{gene_symbol}.csv')
        return None
    return sort_output

In [None]:
# Design: genes to process and the weights file used for scoring.
genes = [
    'TROAP',
    'SPC24',
    'RAD54L',
    'MCM2',
    'COPB2',
    'CKAP5',
]
model_path = '/content/drive/MyDrive/Colab Notebooks/Cas9/On target/saved_model/Cas9_MultiHeadAttention_weights.keras'

# Write one CSV of score-ranked candidates per gene.
for gene in genes:
    gRNADesign(gene, model_path, write_to_csv=True)

### Combine with VCF information

#### Predict cell type-specific gRNAs

In [None]:
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene by symbol on the Ensembl REST API and return its transcripts.

    This re-defines the helper of the same name from the "Predict gRNA in one
    specific gene" section; once this cell runs, this definition is the one in effect.

    Args:
        gene_symbol: Gene symbol inserted into the lookup URL (e.g. ``'MCM2'``).

    Returns:
        The value of the ``'Transcript'`` field of the JSON response, or ``None``
        when the request fails, the status code is not 200, or the field is
        missing. A message is printed in each failure case.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the notebook forever.
        response = requests.get(url, timeout=30)
    except requests.exceptions.RequestException as exc:
        # Same convention as the non-200 path below: report and return None.
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']

def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence of an Ensembl feature by its stable ID.

    This re-defines the helper of the same name from the "Predict gRNA in one
    specific gene" section; once this cell runs, this definition is the one in effect.

    Args:
        transcript_id: Ensembl stable ID inserted into the request URL. Despite
            the parameter name, the callers in this notebook pass exon IDs.

    Returns:
        The ``'seq'`` field of the JSON response, or ``None`` when the request
        fails, the status code is not 200, or the field is missing. A message is
        printed in each failure case.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the notebook forever.
        response = requests.get(url, timeout=30)
    except requests.exceptions.RequestException as exc:
        # Same convention as the non-200 path below: report and return None.
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']


In [None]:
def apply_mutation(ref_sequence, offset, ref, alt):
    """
    Apply a single mutation to the sequence.

    The ``len(ref)`` bases starting at ``offset`` are replaced by ``alt``; an
    ``alt`` of ``"*"`` replaces them with nothing. This covers substitutions,
    insertions and deletions with one rule.

    Args:
        ref_sequence: Sequence to modify.
        offset: 0-based index in ``ref_sequence`` where ``ref`` starts.
        ref: Reference allele; only its length is used.
        alt: Alternative allele, or ``"*"`` for a deleted allele.

    Returns:
        The mutated sequence as a new string.
    """
    replacement = "" if alt == "*" else alt
    # Always skip the whole reference allele. The former insertion branch skipped
    # a single base, which duplicated bases whenever REF was longer than one.
    return ref_sequence[:offset] + replacement + ref_sequence[offset + len(ref):]

In [None]:
def construct_combinations(sequence, mutations):
    """
    Construct all combinations of mutations.
    mutations is a list of tuples (position, ref, [alts])

    Every mutation is applied with each of its alternative alleles, so the
    result holds one sequence per combination of alts (the unmodified reference
    allele is never one of the choices). A mutation with an empty alt list
    leaves the sequences unchanged.

    Mutations are applied from the highest offset to the lowest so that an
    insertion or deletion does not shift the offsets of mutations still to be
    applied. Overlapping mutations are not reconciled.

    Returns:
        List of mutated sequences; ``[sequence]`` when ``mutations`` is empty.
        The order of the list is not part of the contract.
    """
    sequences = [sequence]

    # Right-to-left application keeps every remaining (reference-based) offset valid.
    for offset, ref, alts in sorted(mutations, key=lambda mutation: mutation[0], reverse=True):
        if not alts:
            # Nothing to apply; previously this silently discarded every sequence.
            continue
        sequences = [apply_mutation(current, offset, ref, alt)
                     for current in sequences
                     for alt in alts]

    return sequences


In [None]:
def needleman_wunsch_alignment(query_seq, ref_seq):
    """
    Globally align query_seq to ref_seq and locate the longest exact-match run.

    The alignment's CIGAR string is split into (length, operation) pairs. The
    returned value is the summed length of every operation that precedes the
    longest '=' run (the first one wins on ties), or 0 when there is no '='
    run. Callers use it as the approximate position of a mutated target.

    NOTE(review): blosum62 is a protein substitution matrix; a nucleotide matrix
    may be more appropriate for DNA -- confirm before changing, since it would
    alter the reported positions.
    """
    # Arguments 10 and 1 are presumably gap-open / gap-extend penalties (per
    # parasail's (s1, s2, open, extend, matrix) signature) -- TODO confirm.
    result = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
    cigar_text = result.cigar.decode.decode("utf-8")

    operations = [(int(count), code)
                  for count, code in re.findall(r'(\d+)([MIDNSHP=X])', cigar_text)]

    # Index of the longest '=' run; strict '>' keeps the earliest run on ties.
    best_index = -1
    best_length = 0
    for index, (length, code) in enumerate(operations):
        if code == '=' and length > best_length:
            best_index = index
            best_length = length

    if best_index <= 0:
        # No '=' run at all, or the longest one opens the alignment.
        return 0
    return sum(length for length, _ in operations[:best_index])


In [None]:
def _reverse_complement_allele(allele):
    """Reverse-complement a literal allele; the '*' (deleted) allele is returned unchanged."""
    if allele == "*":
        return allele
    return allele.translate(str.maketrans("ACGTacgt", "TGCAtgca"))[::-1]


def _project_variant(pos, ref, alts, start, end, strand, seq_len):
    """Express a forward-strand VCF record as (offset, ref, alts) in exon-sequence coordinates.

    Returns None when the reference allele does not lie entirely inside the exon.
    """
    if strand == -1:
        # The exon sequence is indexed from `end` backwards (index k <-> end - k),
        # so the allele's LAST genomic base becomes its first index, and the
        # alleles themselves must be reverse-complemented.
        offset = end - (pos + len(ref) - 1)
        ref = _reverse_complement_allele(ref)
        alts = [_reverse_complement_allele(alt) for alt in alts]
    else:
        offset = pos - start
        alts = list(alts)

    if offset < 0 or offset + len(ref) > seq_len:
        return None
    return (offset, ref, alts)


def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
    """Find protospacer + PAM sites in an exon after applying the sample's variants.

    Variants overlapping ``exon_chr:start-end`` are read from ``vcf_reader``,
    projected onto the exon sequence (strand-aware) and combined with
    ``construct_combinations``. Every resulting sequence is scanned for sites
    whose bases after the first PAM position equal ``pam[1:]``.

    Args:
        ref_sequence: Reference exon sequence. For ``strand == -1`` index 0 is
            treated as genomic position ``end``; otherwise as ``start``.
        exon_chr: Chromosome name used in the region query and in the output.
        start: Genomic coordinate of the first base of the exon.
        end: Genomic coordinate of the last base of the exon.
        strand: ``-1`` for the reverse strand; any other value is handled as forward.
        transcript_id: Copied into every result row.
        exon_id: Copied into every result row.
        gene_symbol: Copied into every result row.
        vcf_reader: Callable taking a ``"chrom:start-end"`` region string and
            yielding records with ``POS``, ``REF`` and ``ALT`` attributes.
        pam: PAM pattern; only positions after the first are matched, literally.
        target_length: Protospacer length preceding the PAM.

    Returns:
        De-duplicated rows ``[target_seq, gRNA, exon_chr, strand, tar_start,
        transcript_id, exon_id, gene_symbol, is_mut]`` in first-seen order.
        ``is_mut`` is True when ``target_seq`` does not occur in
        ``ref_sequence``; its ``tar_start`` is then an alignment-based estimate
        and carries a trailing ``'*'``.
    """
    # Collect the variants that can be applied to this exon.
    mutation_list = []
    for mutation in vcf_reader(f"{exon_chr}:{start}-{end}") or []:
        # NOTE(review): the last ALT is always dropped, which presumably removes a
        # trailing symbolic allele such as <NON_REF>; for records with a single
        # real ALT this leaves nothing to apply -- confirm against the VCF.
        # Symbolic alleles ("<...>") cannot be spliced into a sequence.
        alts = [alt for alt in mutation.ALT[:-1] if not alt.startswith("<")]
        if not alts:
            continue
        projected = _project_variant(mutation.POS, mutation.REF, alts,
                                     start, end, strand, len(ref_sequence))
        if projected is not None:
            mutation_list.append(projected)

    if mutation_list:
        mutated_sequences = construct_combinations(ref_sequence, mutation_list)
    else:
        mutated_sequences = [ref_sequence]

    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    pam_len = len(pam)
    pam_suffix = pam[1:]

    # Find gRNAs in the reference sequence or in every mutated sequence.
    targets = []
    for seq in mutated_sequences:
        for i in range(target_length, len(seq) - pam_len + 1):
            if seq[i + 1:i + pam_len] != pam_suffix:
                continue

            target_seq = seq[i - target_length:i + pam_len]
            pos = ref_sequence.find(target_seq)
            is_mut = pos == -1
            if is_mut:
                # Not present in the reference: estimate the position by alignment.
                pos = needleman_wunsch_alignment(target_seq, ref_sequence)

            if strand == -1:
                tar_start = end - pos - target_length - (pam_len - 1)
            else:
                tar_start = start + pos
            tar_start = f"{tar_start}*" if is_mut else str(tar_start)

            gRNA = ''.join(dnatorna[base] for base in seq[i - target_length:i])
            targets.append([target_seq, gRNA, exon_chr, str(strand), tar_start,
                            transcript_id, exon_id, gene_symbol, is_mut])

    # Drop duplicates while keeping a deterministic (first-seen) order.
    return [list(row) for row in dict.fromkeys(tuple(target) for target in targets)]

In [None]:
def format_prediction_output_with_mutation(targets, model_path):
    """Score each (possibly mutated) target and reorder the fields for output.

    Args:
        targets: Rows of ``[target_seq, gRNA, exon_chr, strand, tar_start,
            transcript_id, exon_id, gene_symbol, is_mut]`` as produced by
            ``find_gRNA_with_mutation``.
        model_path: Path of the weights file loaded into a freshly built
            ``MultiHeadAttention_model(input_shape=(23, 4))``.

    Returns:
        List of ``[gene_symbol, exon_chr, strand, tar_start, transcript_id,
        exon_id, target_seq, gRNA, prediction, is_mut]`` rows in the order of
        ``targets``. Predictions above 100 are capped at 100. An empty
        ``targets`` returns ``[]`` without building the model.
    """
    if not targets:
        return []

    # NOTE: the model is rebuilt and its weights reloaded on every call.
    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    # Encode everything first and run ONE batched prediction; calling
    # model.predict() once per target is dramatically slower. Scores may differ
    # from per-sample prediction only by floating-point rounding.
    encoded_batch = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = np.asarray(model.predict(encoded_batch, verbose=0)).reshape(-1)

    formatted_data = []
    for target, score in zip(targets, scores):
        prediction = float(score)
        if prediction > 100:
            prediction = 100

        (target_seq, gRNA, exon_chr, strand, tar_start,
         transcript_id, exon_id, gene_symbol, is_mut) = target[:9]
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
                               exon_id, target_seq, gRNA, prediction, is_mut])

    return formatted_data

In [None]:
def gRNADesign_mutation(gene_symbol, vcf_reader, model_path, write_to_csv=False):
    """Design and score gRNAs for a gene, taking the sample's VCF variants into account.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        vcf_reader: Variant reader forwarded to ``find_gRNA_with_mutation``.
        model_path: Weights path forwarded to
            ``format_prediction_output_with_mutation``.
        write_to_csv: When ``True`` the ranked rows are written to a CSV on
            Drive and nothing is returned.

    Returns:
        The rows sorted by predicted score (highest first) when ``write_to_csv``
        is not ``True``; otherwise ``None``.
    """
    scored_rows = []

    for transcript in fetch_ensembl_transcripts(gene_symbol) or []:
        exons = transcript['Exon']
        transcript_id = transcript['id']

        for exon in exons:
            exon_id = exon['id']
            exon_chr = exon['seq_region_name']
            start = exon['start']
            end = exon['end']
            strand = exon['strand']

            # Reference exon sequence; exons without one are skipped.
            exon_sequence = fetch_ensembl_sequence(exon_id)
            if not exon_sequence:
                continue

            targets = find_gRNA_with_mutation(exon_sequence, exon_chr, start, end, strand,
                                              transcript_id, exon_id, gene_symbol, vcf_reader)
            if not targets:
                continue

            # Score every site and collect the rows straight into one flat list.
            scored_rows.extend(format_prediction_output_with_mutation(targets, model_path))

    header = ['Gene','Chrom','Strand','Start','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score','Is_mutation']
    # Column 8 holds the predicted score.
    sort_output = sorted(scored_rows, key=lambda row: row[8], reverse=True)

    if write_to_csv == True:
        frame = pd.DataFrame(data=sort_output, columns=header)
        frame.to_csv(f'/content/drive/MyDrive/Colab Notebooks/Cas9/On target/design_results/Cas9_{gene_symbol}_mut.csv')
        return None
    return sort_output

In [None]:
# read VCF file
# Open the sample's variant calls (filtered SNPs/indels, per the file name) with cyvcf2.
# The reader is passed to gRNADesign_mutation, which queries it with
# "chrom:start-end" region strings.
# NOTE(review): region queries presumably need a tabix/CSI index next to the
# .vcf.gz file -- confirm it exists on Drive.
vcf_reader = cyvcf2.VCF('/content/drive/MyDrive/Colab Notebooks/CRISPR_data/SRR25934512.filter.snps.indels.vcf.gz')

In [None]:
# Design: genes to process and the weights file used for scoring.
genes = [
    'TROAP',
    'SPC24',
    'RAD54L',
    'MCM2',
    'COPB2',
    'CKAP5',
]
model_path = '/content/drive/MyDrive/Colab Notebooks/Cas9/On target/saved_model/Cas9_MultiHeadAttention_weights.keras'

# Write one variant-aware CSV of score-ranked candidates per gene.
for gene in genes:
    gRNADesign_mutation(gene, vcf_reader, model_path, write_to_csv=True)