|
import numpy as np |
|
import pandas as pd |
|
from collections import defaultdict |
|
import random |
|
import os |
|
import torch |
|
from Bio.Align.Applications import ClustalOmegaCommandline |
|
|
|
def filter_msa(msa_data, num_sequences_kept=3): |
|
""" |
|
Helper function to filter an input MSA msa_data (obtained via process_msa_data) and keep only num_sequences_kept aligned sequences. |
|
If the MSA already has fewer sequences than num_sequences_kept, we keep the MSA as is. |
|
If filtering, we always keep the first sequence of the MSA (ie. the wild type) by default. |
|
Sampling is done without replacement. |
|
""" |
|
if len(list(msa_data.keys())) <= num_sequences_kept: |
|
return msa_data |
|
filtered_msa = {} |
|
wt_name = next(iter(msa_data)) |
|
filtered_msa[wt_name] = msa_data[wt_name] |
|
del msa_data[wt_name] |
|
sequence_names = list(msa_data.keys()) |
|
sequence_names_sampled = random.sample(sequence_names,k=num_sequences_kept-1) |
|
for seq in sequence_names_sampled: |
|
filtered_msa[seq] = msa_data[seq] |
|
return filtered_msa |
|
|
|
def process_msa_data(MSA_data_file): |
|
""" |
|
Helper function that takes as input a path to a MSA file (expects a2m format) and returns a dict mapping sequence ID to the corresponding AA sequence. |
|
""" |
|
msa_data = defaultdict(str) |
|
sequence_name = "" |
|
with open(MSA_data_file, "r") as msa_file: |
|
for i, line in enumerate(msa_file): |
|
line = line.rstrip() |
|
if line.startswith(">"): |
|
sequence_name = line |
|
else: |
|
msa_data[sequence_name] += line.upper() |
|
return msa_data |
|
|
|
def get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab): |
|
vocab_size = len(vocab.keys()) |
|
num_sequences_msa = len(msa_data.keys()) |
|
one_hots = np.zeros((num_sequences_msa,MSA_end-MSA_start,vocab_size)) |
|
for i,seq_name in enumerate(msa_data.keys()): |
|
sequence = msa_data[seq_name] |
|
for j,letter in enumerate(sequence): |
|
if letter in vocab: |
|
k = vocab[letter] |
|
one_hots[i,j,k] = 1.0 |
|
return one_hots |
|
|
|
def one_hot(sequence_string,vocab): |
|
one_hots = np.zeros((len(sequence_string),len(vocab.keys()))) |
|
for j,letter in enumerate(sequence_string): |
|
if letter in vocab: |
|
k = vocab[letter] |
|
one_hots[j,k] = 1.0 |
|
return one_hots.flatten() |
|
|
|
def get_msa_prior(MSA_data_file, MSA_weight_file_name, MSA_start, MSA_end, len_target_seq, vocab, retrieval_aggregation_mode="aggregate_substitution", filter_MSA=True, verbose=False): |
|
""" |
|
Function to enable retrieval inference mode, via computation of (weighted) pseudocounts of AAs at each position of the retrieved MSA. |
|
MSA_data_file: (string) path to MSA file (expects a2m format). |
|
MSA_weight_file_name: (string) path to sequence weights in MSA. |
|
MSA_start: (int) Sequence position that the MSA starts at (1-indexing). |
|
MSA_end: (int) Sequence position that the MSA ends at (1-indexing). |
|
len_target_seq: (int) Full length of sequence to be scored. |
|
vocab: (dict) Vocabulary of the tokenizer. |
|
retrieval_aggregation_mode: (string) Mode for retrieval inference (aggregate_substitution Vs aggregate_indel). If None, places a uniform prior over each token. |
|
filter_MSA: (bool) Whether to filter out sequences with very low hamming similarity (< 0.2) to the reference sequence in the MSA (first sequence). |
|
verbose: (bool) Whether to print to the console processing details along the way. |
|
""" |
|
msa_data = process_msa_data(MSA_data_file) |
|
vocab_size = len(vocab.keys()) |
|
if verbose: print("Target seq len is {}, MSA length is {}, start position is {}, end position is {} and vocab size is {}".format(len_target_seq,MSA_end-MSA_start,MSA_start,MSA_end,vocab_size)) |
|
|
|
if filter_MSA: |
|
if verbose: print("Num sequences in MSA pre filtering: {}".format(len(msa_data.keys()))) |
|
list_sequence_names = list(msa_data.keys()) |
|
focus_sequence_name = list(msa_data.keys())[0] |
|
ref_sequence_hot = one_hot(msa_data[focus_sequence_name],vocab) |
|
for sequence_name in list_sequence_names: |
|
seq_hot = one_hot(msa_data[sequence_name],vocab) |
|
hamming_similarity_seq_ref = np.dot(ref_sequence_hot,seq_hot) / np.dot(ref_sequence_hot,ref_sequence_hot) |
|
if hamming_similarity_seq_ref < 0.2: |
|
del msa_data[sequence_name] |
|
if verbose: print("Num sequences in MSA post filtering: {}".format(len(msa_data.keys()))) |
|
|
|
if MSA_weight_file_name is not None: |
|
if verbose: print("Using weights in {} for sequences in MSA.".format(MSA_weight_file_name)) |
|
assert os.path.exists(MSA_weight_file_name), "Weights file not located on disk." |
|
MSA_EVE = MSA_processing( |
|
MSA_location=MSA_data_file, |
|
use_weights=True, |
|
weights_location=MSA_weight_file_name |
|
) |
|
|
|
dropped_sequences=0 |
|
list_sequence_names = list(msa_data.keys()) |
|
MSA_weight=[] |
|
for sequence_name in list_sequence_names: |
|
if sequence_name not in MSA_EVE.seq_name_to_sequence: |
|
dropped_sequences +=1 |
|
del msa_data[sequence_name] |
|
else: |
|
MSA_weight.append(MSA_EVE.seq_name_to_weight[sequence_name]) |
|
if verbose: print("Dropped {} sequences from MSA due to absent sequence weights".format(dropped_sequences)) |
|
else: |
|
MSA_weight = [1] * len(list(msa_data.keys())) |
|
|
|
if retrieval_aggregation_mode=="aggregate_substitution" or retrieval_aggregation_mode=="aggregate_indel": |
|
one_hots = get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab) |
|
MSA_weight = np.expand_dims(np.array(MSA_weight),axis=(1,2)) |
|
base_rate = 1e-5 |
|
base_rates = np.ones_like(one_hots) * base_rate |
|
weighted_one_hots = (one_hots + base_rates) * MSA_weight |
|
MSA_weight_norm_counts = weighted_one_hots.sum(axis=-1).sum(axis=0) |
|
MSA_weight_norm_counts = np.tile(MSA_weight_norm_counts.reshape(-1,1), (1,vocab_size)) |
|
one_hots_avg = weighted_one_hots.sum(axis=0) / MSA_weight_norm_counts |
|
msa_prior = np.zeros((len_target_seq,vocab_size)) |
|
msa_prior[MSA_start:MSA_end,:]=one_hots_avg |
|
else: |
|
msa_prior = np.ones((len_target_seq,vocab_size)) / vocab_size |
|
|
|
if verbose: |
|
for idx, position in enumerate(msa_prior): |
|
if len(position)!=25: |
|
print("Size error") |
|
if not round(position.sum(),2)==1.0: |
|
print("Position at index {} does not add up to 1: {}".format(idx, position.sum())) |
|
|
|
return msa_prior |
|
|
|
|
|
def update_retrieved_MSA_log_prior_indel(model, MSA_log_prior, MSA_start, MSA_end, full_raw_sequence): |
|
""" |
|
Function to process MSA when scoring indels. |
|
To identify positions to add / remove in the retrieved MSA, we append and align the sequence to be scored to the original MSA for that protein family with Clustal Omega. |
|
If the original MSA is relatively deep (over 100k sequences), we sample (by default) 100k rows at random from that MSA to speed computations. |
|
MSA sampling is performed only once (for the first sequence to be scored). Subsequent scoring use the same MSA sample. |
|
""" |
|
if not os.path.isdir(model.MSA_folder + os.sep + "Sampled"): |
|
os.mkdir(model.MSA_folder + os.sep + "Sampled") |
|
sampled_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Sampled_" + model.MSA_filename.split(os.sep)[-1] |
|
|
|
if not os.path.exists(sampled_MSA_location): |
|
msa_data = process_msa_data(model.MSA_filename) |
|
msa_data_sampled = filter_msa(msa_data, num_sequences_kept=100000) |
|
with open(sampled_MSA_location, 'w') as sampled_write_location: |
|
for index, key in enumerate(msa_data_sampled): |
|
key_name = ">REFERENCE_SEQUENCE" if index==0 else key |
|
msa_data_sampled[key] = msa_data_sampled[key].upper() |
|
msa_data_sampled[key] = msa_data_sampled[key].replace(".","-") |
|
sampled_write_location.write(key_name+"\n"+"\n".join([msa_data_sampled[key][i:i+80] for i in range(0, len(msa_data_sampled[key]), 80)])+"\n") |
|
|
|
seq_to_align_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Seq_to_align_" + model.MSA_filename.split(os.sep)[-1] |
|
sequence_text_split = [full_raw_sequence[i:i+80] for i in range(0, len(full_raw_sequence), 80)] |
|
sequence_text_split_split_join = "\n".join([">SEQ_TO_SCORE"]+sequence_text_split) |
|
os.system("echo '"+sequence_text_split_split_join+"' > "+seq_to_align_location) |
|
|
|
expanded_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Expanded_" + model.MSA_filename.split(os.sep)[-1] |
|
clustalw_cline = ClustalOmegaCommandline(cmd=model.config.clustal_omega_location, |
|
profile1=sampled_MSA_location, |
|
profile2=seq_to_align_location, |
|
outfile=expanded_MSA_location, |
|
force=True) |
|
stdout, stderr = clustalw_cline() |
|
msa_data = process_msa_data(expanded_MSA_location) |
|
aligned_seqA, aligned_seqB = msa_data[">SEQ_TO_SCORE"], msa_data[">REFERENCE_SEQUENCE"] |
|
try: |
|
keep_column=[] |
|
for column_index_pairwise_alignment in range(len(aligned_seqA)): |
|
if aligned_seqA[column_index_pairwise_alignment]=="-" and aligned_seqB[column_index_pairwise_alignment]=="-": |
|
continue |
|
elif aligned_seqA[column_index_pairwise_alignment]=="-": |
|
keep_column.append(False) |
|
elif aligned_seqB[column_index_pairwise_alignment]=="-": |
|
MSA_log_prior=torch.cat((MSA_log_prior[:column_index_pairwise_alignment], torch.zeros(MSA_log_prior.shape[1]).view(1,-1).cuda(), MSA_log_prior[column_index_pairwise_alignment:]),dim=0) |
|
keep_column.append(True) |
|
else: |
|
keep_column.append(True) |
|
MSA_log_prior = MSA_log_prior[keep_column] |
|
MSA_end = MSA_start + len(MSA_log_prior) |
|
except: |
|
print("Error when processing the following alignment: {}".format(expanded_MSA_location)) |
|
return MSA_log_prior, MSA_start, MSA_end |
|
|
|
class MSA_processing: |
|
def __init__(self, |
|
MSA_location="", |
|
theta=0.2, |
|
use_weights=True, |
|
weights_location="./data/weights", |
|
preprocess_MSA=True, |
|
threshold_sequence_frac_gaps=0.5, |
|
threshold_focus_cols_frac_gaps=0.3, |
|
remove_sequences_with_indeterminate_AA_in_focus_cols=True |
|
): |
|
|
|
""" |
|
This MSA_processing class is directly borrowed from the EVE codebase: https://github.com/OATML-Markslab/EVE |
|
|
|
Parameters: |
|
- msa_location: (path) Location of the MSA data. Constraints on input MSA format: |
|
- focus_sequence is the first one in the MSA data |
|
- first line is structured as follows: ">focus_seq_name/start_pos-end_pos" (e.g., >SPIKE_SARS2/310-550) |
|
- corespondding sequence data located on following line(s) |
|
- then all other sequences follow with ">name" on first line, corresponding data on subsequent lines |
|
- theta: (float) Sequence weighting hyperparameter. Generally: Prokaryotic and eukaryotic families = 0.2; Viruses = 0.01 |
|
- use_weights: (bool) If False, sets all sequence weights to 1. If True, checks weights_location -- if non empty uses that; |
|
otherwise compute weights from scratch and store them at weights_location |
|
- weights_location: (path) Location to load from/save to the sequence weights |
|
- preprocess_MSA: (bool) performs pre-processing of MSA to remove short fragments and positions that are not well covered. |
|
- threshold_sequence_frac_gaps: (float, between 0 and 1) Threshold value to define fragments |
|
- sequences with a fraction of gap characters above threshold_sequence_frac_gaps are removed |
|
- default is set to 0.5 (i.e., fragments with 50% or more gaps are removed) |
|
- threshold_focus_cols_frac_gaps: (float, between 0 and 1) Threshold value to define focus columns |
|
- positions with a fraction of gap characters above threshold_focus_cols_pct_gaps will be set to lower case (and not included in the focus_cols) |
|
- default is set to 0.3 (i.e., focus positions are the ones with 30% of gaps or less, i.e., 70% or more residue occupancy) |
|
- remove_sequences_with_indeterminate_AA_in_focus_cols: (bool) Remove all sequences that have indeterminate AA (e.g., B, J, X, Z) at focus positions of the wild type |
|
""" |
|
np.random.seed(2021) |
|
self.MSA_location = MSA_location |
|
self.weights_location = weights_location |
|
self.theta = theta |
|
self.alphabet = "ACDEFGHIKLMNPQRSTVWY" |
|
self.use_weights = use_weights |
|
self.preprocess_MSA = preprocess_MSA |
|
self.threshold_sequence_frac_gaps = threshold_sequence_frac_gaps |
|
self.threshold_focus_cols_frac_gaps = threshold_focus_cols_frac_gaps |
|
self.remove_sequences_with_indeterminate_AA_in_focus_cols = remove_sequences_with_indeterminate_AA_in_focus_cols |
|
|
|
self.gen_alignment() |
|
|
|
def gen_alignment(self, verbose=False): |
|
""" Read training alignment and store basics in class instance """ |
|
self.aa_dict = {} |
|
for i,aa in enumerate(self.alphabet): |
|
self.aa_dict[aa] = i |
|
|
|
self.seq_name_to_sequence = defaultdict(str) |
|
name = "" |
|
with open(self.MSA_location, "r") as msa_data: |
|
for i, line in enumerate(msa_data): |
|
line = line.rstrip() |
|
if line.startswith(">"): |
|
name = line |
|
if i==0: |
|
self.focus_seq_name = name |
|
else: |
|
self.seq_name_to_sequence[name] += line |
|
|
|
|
|
|
|
if self.preprocess_MSA: |
|
msa_df = pd.DataFrame.from_dict(self.seq_name_to_sequence, orient='index', columns=['sequence']) |
|
|
|
msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".","-")).apply(lambda x: ''.join([aa.upper() for aa in x])) |
|
|
|
non_gap_wt_cols = [aa!='-' for aa in msa_df.sequence[self.focus_seq_name]] |
|
msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa for aa,non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind])) |
|
assert 0.0 <= self.threshold_sequence_frac_gaps <= 1.0,"Invalid fragment filtering parameter" |
|
assert 0.0 <= self.threshold_focus_cols_frac_gaps <= 1.0,"Invalid focus position filtering parameter" |
|
msa_array = np.array([list(seq) for seq in msa_df.sequence]) |
|
gaps_array = np.array(list(map(lambda seq: [aa=='-' for aa in seq], msa_array))) |
|
|
|
seq_gaps_frac = gaps_array.mean(axis=1) |
|
seq_below_threshold = seq_gaps_frac <= self.threshold_sequence_frac_gaps |
|
if verbose: print("Proportion of sequences dropped due to fraction of gaps: "+str(round(float(1 - seq_below_threshold.sum()/seq_below_threshold.shape)*100,2))+"%") |
|
|
|
columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0) |
|
index_cols_below_threshold = columns_gaps_frac <= self.threshold_focus_cols_frac_gaps |
|
if verbose: print("Proportion of non-focus columns removed: "+str(round(float(1 - index_cols_below_threshold.sum()/index_cols_below_threshold.shape)*100,2))+"%") |
|
|
|
msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa.upper() if upper_case_ind else aa.lower() for aa, upper_case_ind in zip(x, index_cols_below_threshold)])) |
|
msa_df = msa_df[seq_below_threshold] |
|
|
|
self.seq_name_to_sequence = defaultdict(str) |
|
for seq_idx in range(len(msa_df['sequence'])): |
|
self.seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx] |
|
|
|
self.focus_seq = self.seq_name_to_sequence[self.focus_seq_name] |
|
self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s!='-'] |
|
self.focus_seq_trimmed = [self.focus_seq[ix] for ix in self.focus_cols] |
|
self.seq_len = len(self.focus_cols) |
|
self.alphabet_size = len(self.alphabet) |
|
|
|
|
|
focus_loc = self.focus_seq_name.split("/")[-1] |
|
start,stop = focus_loc.split("-") |
|
self.focus_start_loc = int(start) |
|
self.focus_stop_loc = int(stop) |
|
self.uniprot_focus_col_to_wt_aa_dict \ |
|
= {idx_col+int(start):self.focus_seq[idx_col] for idx_col in self.focus_cols} |
|
self.uniprot_focus_col_to_focus_idx \ |
|
= {idx_col+int(start):idx_col for idx_col in self.focus_cols} |
|
|
|
|
|
self.raw_seq_name_to_sequence = self.seq_name_to_sequence.copy() |
|
for seq_name,sequence in self.seq_name_to_sequence.items(): |
|
sequence = sequence.replace(".","-") |
|
self.seq_name_to_sequence[seq_name] = [sequence[ix].upper() for ix in self.focus_cols] |
|
|
|
|
|
if self.remove_sequences_with_indeterminate_AA_in_focus_cols: |
|
alphabet_set = set(list(self.alphabet)) |
|
seq_names_to_remove = [] |
|
for seq_name,sequence in self.seq_name_to_sequence.items(): |
|
for letter in sequence: |
|
if letter not in alphabet_set and letter != "-": |
|
seq_names_to_remove.append(seq_name) |
|
continue |
|
seq_names_to_remove = list(set(seq_names_to_remove)) |
|
for seq_name in seq_names_to_remove: |
|
del self.seq_name_to_sequence[seq_name] |
|
|
|
|
|
self.one_hot_encoding = np.zeros((len(self.seq_name_to_sequence.keys()),len(self.focus_cols),len(self.alphabet))) |
|
if verbose: print("One-hot encoded sequences shape:" + str(self.one_hot_encoding.shape)) |
|
for i,seq_name in enumerate(self.seq_name_to_sequence.keys()): |
|
sequence = self.seq_name_to_sequence[seq_name] |
|
for j,letter in enumerate(sequence): |
|
if letter in self.aa_dict: |
|
k = self.aa_dict[letter] |
|
self.one_hot_encoding[i,j,k] = 1.0 |
|
|
|
if self.use_weights: |
|
try: |
|
self.weights = np.load(file=self.weights_location) |
|
if verbose: print("Loaded sequence weights from disk") |
|
except: |
|
if verbose: print ("Computing sequence weights") |
|
list_seq = self.one_hot_encoding |
|
list_seq = list_seq.reshape((list_seq.shape[0], list_seq.shape[1] * list_seq.shape[2])) |
|
def compute_weight(seq): |
|
number_non_empty_positions = np.dot(seq,seq) |
|
if number_non_empty_positions>0: |
|
denom = np.dot(list_seq,seq) / np.dot(seq,seq) |
|
denom = np.sum(denom > 1 - self.theta) |
|
return 1/denom |
|
else: |
|
return 0.0 |
|
self.weights = np.array(list(map(compute_weight,list_seq))) |
|
np.save(file=self.weights_location, arr=self.weights) |
|
else: |
|
|
|
if verbose: print("Not weighting sequence data") |
|
self.weights = np.ones(self.one_hot_encoding.shape[0]) |
|
|
|
self.Neff = np.sum(self.weights) |
|
self.num_sequences = self.one_hot_encoding.shape[0] |
|
self.seq_name_to_weight={} |
|
for i,seq_name in enumerate(self.seq_name_to_sequence.keys()): |
|
self.seq_name_to_weight[seq_name]=self.weights[i] |
|
|
|
if verbose: |
|
print ("Neff =",str(self.Neff)) |
|
print ("Data Shape =",self.one_hot_encoding.shape) |