CRISPRTool / cas9off.py
supercat666's picture
fixed cas9off
a5afc1a
raw
history blame
5.41 kB
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import argparse
# Configure GPUs: switch on memory growth for every physical GPU that
# TensorFlow reports, then make only the first reported GPU visible.
_gpu_devices = tf.config.list_physical_devices('GPU')
for gpu in _gpu_devices:
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if _gpu_devices:
    tf.config.experimental.set_visible_devices(_gpu_devices[0], 'GPU')
class Encoder:
    """Encode an (on-target, off-target) sequence pair as a 24 x 7 matrix.

    Both sequences are left-padded with '-' to 24 characters. Each position
    is then described by 7 numbers:

    * 5 channels: the element-wise OR of the one-hot codes of the two bases
      (channel order A, T, G, C, '_'; the padding character '-' encodes as
      all zeros). '_' is presumably an indel/bulge marker -- the lookup
      table is named ``encoded_dict_indel``; confirm against the model docs.
    * 2 channels: a "direction" flag. Bases are ranked A > G > C > T > '_';
      channel 0 is set when the on-target base ranks higher than the
      off-target base, channel 1 when it ranks lower, and neither when the
      bases are equal or either position is padding.

    An 'N' in the on-target sequence takes the off-target base at the same
    position.

    Attributes set by construction:
        on_seq, off_seq: the padded sequences.
        sgRNA_code: (24, 5) integer one-hot code of the on-target sequence.
        off_code: (24, 5) integer one-hot code of the off-target sequence.
        on_off_code: (24, 7) float matrix described above.
        label: only set when ``with_category`` is true.
        value: only set when ``with_reg_val`` is true.

    Raises:
        ValueError: if either sequence is longer than 24 characters.
        KeyError: if a sequence contains a character that is not in the
            encoding table.
    """

    def __init__(self, on_seq, off_seq, with_category=False, label=None, with_reg_val=False, value=None):
        tlen = 24
        # Reject over-long sequences up front. Without this check an
        # over-long off_seq would be silently truncated and an over-long
        # on_seq would fail later with an unrelated IndexError / reshape
        # error.
        for arg_name, seq in (("on_seq", on_seq), ("off_seq", off_seq)):
            if len(seq) > tlen:
                raise ValueError(
                    f"{arg_name} has {len(seq)} characters; at most {tlen} are supported"
                )
        self.on_seq = "-" * (tlen - len(on_seq)) + on_seq
        self.off_seq = "-" * (tlen - len(off_seq)) + off_seq
        self.encoded_dict_indel = {'A': [1, 0, 0, 0, 0], 'T': [0, 1, 0, 0, 0],
                                   'G': [0, 0, 1, 0, 0], 'C': [0, 0, 0, 1, 0],
                                   '_': [0, 0, 0, 0, 1], '-': [0, 0, 0, 0, 0]}
        self.direction_dict = {'A': 5, 'G': 4, 'C': 3, 'T': 2, '_': 1}
        if with_category:
            self.label = label
        if with_reg_val:
            self.value = value
        self.encode_on_off_dim7()

    def encode_sgRNA(self):
        """One-hot encode the on-target sequence into ``self.sgRNA_code``."""
        table = self.encoded_dict_indel
        # 'N' in the on-target sequence adopts the off-target base.
        resolved = [
            off_b if on_b == "N" else on_b
            for on_b, off_b in zip(self.on_seq, self.off_seq)
        ]
        self.sgRNA_code = np.array([table[base] for base in resolved])

    def encode_off(self):
        """One-hot encode the off-target sequence into ``self.off_code``."""
        table = self.encoded_dict_indel
        self.off_code = np.array([table[base] for base in self.off_seq])

    def encode_on_off_dim7(self):
        """Build the (24, 7) pair encoding and store it in ``self.on_off_code``."""
        self.encode_sgRNA()
        self.encode_off()
        rank = self.direction_dict
        rows = []
        for i, (on_b, off_b) in enumerate(zip(self.on_seq, self.off_seq)):
            diff_code = np.bitwise_or(self.sgRNA_code[i], self.off_code[i])
            if on_b == "N":
                on_b = off_b
            # np.zeros yields floats, so the concatenated row is float-typed.
            dir_code = np.zeros(2)
            if on_b != "-" and off_b != "-" and rank[on_b] != rank[off_b]:
                if rank[on_b] > rank[off_b]:
                    dir_code[0] = 1
                else:
                    dir_code[1] = 1
            rows.append(np.concatenate((diff_code, dir_code)))
        self.on_off_code = np.array(rows)
def encode_on_off_seq_pairs(input_file, output_file="CRISPR_net_results.csv"):
    """Score every sequence pair in a CSV file and write the results to disk.

    Args:
        input_file: path (or file-like object) of a header-less,
            comma-separated file whose two columns are read as the
            on-target and off-target sequence.
        output_file: where the result CSV is written. Defaults to
            "CRISPR_net_results.csv" in the current working directory,
            which is the path this function always used before the
            parameter existed.

    The output contains the columns ``on_seq``, ``off_seq`` and
    ``CRISPR_Net_score``, one row per input row.

    NOTE(review): rows with missing values are not filtered here, so a
    blank cell will fail inside ``Encoder``.
    """
    inputs = pd.read_csv(input_file, delimiter=",", header=None, names=['on_seq', 'off_seq'])
    input_codes = np.array([
        Encoder(on_seq=on_seq, off_seq=off_seq).on_off_code
        for on_seq, off_seq in zip(inputs['on_seq'], inputs['off_seq'])
    ])
    # One 24 x 7 matrix per pair, reshaped to the 4-D layout passed to the model.
    input_codes = input_codes.reshape((len(input_codes), 1, 24, 7))
    inputs['CRISPR_Net_score'] = CRISPR_net_predict(input_codes)
    inputs.to_csv(output_file, index=False)
def CRISPR_net_predict(X_test):
    """Load the CRISPR-Net Keras model and return its predictions for X_test.

    The model architecture is read from a JSON file and the weights from an
    HDF5 file, both under ``cas9_model/`` relative to the current working
    directory. The model is rebuilt on every call.

    Args:
        X_test: model input; callers in this file pass arrays shaped
            (n, 1, 24, 7).

    Returns:
        The model output flattened to one dimension.
    """
    # The context manager closes the file even if reading raises.
    with open("cas9_model/CRISPR_Net_CIRCLE_elevation_SITE_structure.json", 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = tf.keras.models.model_from_json(loaded_model_json)  # Updated for TensorFlow 2
    loaded_model.load_weights("cas9_model/CRISPR_Net_CIRCLE_elevation_SITE_weights.h5")
    return loaded_model.predict(X_test).flatten()
def process_input_and_predict(input_data, input_type='manual'):
    """Score on-/off-target sequence pairs given as text or as a CSV file.

    Args:
        input_data: for ``input_type='manual'``, a string with one
            "on_seq,off_seq" pair per line; for ``input_type='file'``, a
            path (or file-like object) of a header-less two-column CSV.
        input_type: ``'manual'`` or ``'file'``.

    Returns:
        A DataFrame with the columns ``on_seq``, ``off_seq`` and
        ``CRISPR_Net_score``, one row per usable pair. Pairs in which
        either sequence is missing or empty are skipped. If no usable
        pair remains, an empty DataFrame with those columns is returned
        and the model is not loaded.

    Raises:
        ValueError: if ``input_type`` is not recognised, or a manual line
            has more than two comma-separated fields.
    """
    if input_type == 'manual':
        pairs = []
        # splitlines() also handles Windows line endings, which would
        # otherwise leave a stray '\r' on the off-target sequence.
        for line in input_data.splitlines():
            fields = line.split(',')
            if len(fields) > 2:
                raise ValueError(
                    f"expected 'on_seq,off_seq' but got {len(fields)} fields: {line!r}"
                )
            # Pad short lines so the emptiness check below drops them.
            fields += [''] * (2 - len(fields))
            pairs.append((fields[0], fields[1]))
    elif input_type == 'file':
        inputs = pd.read_csv(input_data, delimiter=",", header=None, names=['on_seq', 'off_seq'])
        pairs = list(zip(inputs['on_seq'], inputs['off_seq']))
    else:
        raise ValueError(f"input_type must be 'manual' or 'file', got {input_type!r}")

    valid_inputs = []
    input_codes = []
    for raw_pair in pairs:
        cleaned = []
        for seq in raw_pair:
            if isinstance(seq, str):
                seq = seq.strip()
            elif pd.isna(seq):
                # Blank CSV cells arrive as NaN, which is truthy and would
                # slip past a plain `not seq` test.
                seq = ''
            cleaned.append(seq)
        on_seq, off_seq = cleaned
        if not on_seq or not off_seq:
            continue
        en = Encoder(on_seq=on_seq, off_seq=off_seq)
        input_codes.append(en.on_off_code)
        valid_inputs.append((on_seq, off_seq))

    if not valid_inputs:
        # Nothing to score: skip loading the model.
        return pd.DataFrame(columns=['on_seq', 'off_seq', 'CRISPR_Net_score'])

    input_codes = np.array(input_codes)
    input_codes = input_codes.reshape((len(input_codes), 1, 24, 7))
    y_pred = CRISPR_net_predict(input_codes)
    # Create a new DataFrame from valid inputs and predictions
    result_df = pd.DataFrame(valid_inputs, columns=['on_seq', 'off_seq'])
    result_df['CRISPR_Net_score'] = y_pred
    return result_df
if __name__ == '__main__':
    # Command-line entry point: score the pairs listed in the given CSV file.
    cli = argparse.ArgumentParser(description="CRISPR-Net v1.0 (Aug 10 2019)")
    cli.add_argument(
        "input_file",
        help="input_file example (on-target seq, off-target seq):\n GAGT_CCGAGCAGAAGAAGAATGG,GAGTACCAAGTAGAAGAAAAATTT\n"
             "GTTGCCCCACAGGGCAGTAAAGG,GTGGACACCCCGGGCAGGAAAGG\n"
             "GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG",
    )
    input_path = cli.parse_args().input_file
    if os.path.exists(input_path):
        encode_on_off_seq_pairs(input_path)
    else:
        print("File doesn't exist!")
    # Clear the Keras session whether or not a prediction was made.
    tf.keras.backend.clear_session()