import re

import numpy as np
import pandas as pd
import torch
import tqdm
from torch.nn import CrossEntropyLoss, NLLLoss
from torch.utils.data.sampler import SequentialSampler

from transformers import DataCollatorForLanguageModeling
from datasets import Dataset

AA_vocab = "ACDEFGHIKLMNPQRSTVWY"


def get_mutated_sequence(focus_seq, mutant, start_idx=1, AA_vocab=AA_vocab):
    """
    Helper function that mutates an input sequence (focus_seq) via an input mutation triplet (substitutions only).
    Mutation triplets are typically based on 1-indexing: start_idx is used for switching to 0-indexing.
    """
    mutated_seq = list(focus_seq)
    for mutation in mutant.split(":"):
        try:
            from_AA, position, to_AA = mutation[0], int(mutation[1:-1]), mutation[-1]
        except (ValueError, IndexError):
            # Fail fast on malformed triplets; continuing would crash below with an undefined position.
            raise ValueError("Issue with mutant: " + str(mutation))
        relative_position = position - start_idx
        assert (from_AA == focus_seq[relative_position]), "Invalid from_AA or mutant position: " + str(mutation) + " from_AA: " + str(from_AA) + " relative pos: " + str(relative_position) + " focus_seq: " + str(focus_seq)
        assert (to_AA in AA_vocab), "Mutant to_AA is invalid: " + str(mutation)
        mutated_seq[relative_position] = to_AA
    return "".join(mutated_seq)


def nanmean(v, *args, inplace=False, **kwargs):
    """NaN-aware mean: zeroes out NaN entries and divides by the count of non-NaN entries."""
    if not inplace:
        v = v.clone()
    is_nan = torch.isnan(v)
    v[is_nan] = 0
    return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)


def nansum(v, *args, inplace=False, **kwargs):
    """NaN-aware sum: zeroes out NaN entries before summing."""
    if not inplace:
        v = v.clone()
    is_nan = torch.isnan(v)
    v[is_nan] = 0
    return v.sum(*args, **kwargs)
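# Illustrative usage:
# >>> nanmean(torch.tensor([1.0, float('nan'), 3.0]))
# tensor(2.)
# >>> nansum(torch.tensor([1.0, float('nan'), 3.0]))
# tensor(4.)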


def get_optimal_window(mutation_position_relative, seq_len_wo_special, model_window):
    """
    Helper function that selects an optimal sequence window that fits the maximum model context size.
    If the sequence length is no greater than the maximum context size, the full sequence is returned.
    """
    half_model_window = model_window // 2
    if seq_len_wo_special <= model_window:
        return [0, seq_len_wo_special]
    elif mutation_position_relative < half_model_window:
        return [0, model_window]
    elif mutation_position_relative >= seq_len_wo_special - half_model_window:
        return [seq_len_wo_special - model_window, seq_len_wo_special]
    else:
        return [max(0, mutation_position_relative - half_model_window), min(seq_len_wo_special, mutation_position_relative + half_model_window)]
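# Illustrative usage (positions are 0-indexed; lengths exclude special tokens):
# >>> get_optimal_window(10, 100, 50)   # mutation near the start -> left-aligned window
# [0, 50]
# >>> get_optimal_window(90, 100, 50)   # mutation near the end -> right-aligned window
# [50, 100]
# >>> get_optimal_window(50, 200, 50)   # interior mutation -> window centered on it
# [25, 75]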


def sequence_replace_single(sequence, char_to_replace, char_replacements):
    char_replacements = list(char_replacements)
    positions = [m.start() for m in re.finditer(char_to_replace, sequence)]
    replacements = np.random.choice(a=char_replacements, size=len(positions), replace=True)
    sequence = list(sequence)
    for idx, position in enumerate(positions):
        sequence[position] = replacements[idx]
    return ''.join(sequence)


def sequence_replace(sequences, char_to_replace, char_replacements):
    """
    Helper function that replaces all Amino Acids passed in via char_to_replace (as a string of AAs) with Amino Acids sampled from char_replacements (also a string of eligible AAs).
    """
    return [sequence_replace_single(sequence, char_to_replace, char_replacements) for sequence in sequences]
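# Illustrative usage: replace ambiguous residues (e.g. 'X') with residues sampled
# uniformly from the eligible alternatives. Sampling is random, so one possible
# output is shown:
# >>> sequence_replace(["AXDX"], "X", "LIV")
# ['ALDV']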


def get_tranception_scores_mutated_sequences(model, mutated_sequence_df, batch_size_inference, score_var_name, len_target_seq, num_workers=10, reverse=False, indel_mode=False):
    """
    Helper function that takes as input a set of mutated sequences (in a pandas dataframe) and returns scores for each mutation (delta log likelihood wrt wild type sequence).
    """
    scores = {}
    scores['mutant'] = []
    scores['window_start'] = []
    scores['window_end'] = []
    scores['score'] = []
    with torch.no_grad():
        ds = Dataset.from_pandas(mutated_sequence_df)
        ds.set_transform(model.encode_batch)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=model.config.tokenizer,
            mlm=False)
        sampler = SequentialSampler(ds)
        ds_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size_inference, sampler=sampler, collate_fn=data_collator, num_workers=num_workers, pin_memory=True, drop_last=False)
        mutant_index = 0
        for encoded_batch in tqdm.tqdm(ds_loader):
            full_batch_length = len(encoded_batch['input_ids'])
            scores['mutant'] += list(mutated_sequence_df['mutant'][mutant_index:mutant_index+full_batch_length])
            window_start = np.array(mutated_sequence_df['window_start'][mutant_index:mutant_index+full_batch_length])
            scores['window_start'] += list(window_start)
            window_end = np.array(mutated_sequence_df['window_end'][mutant_index:mutant_index+full_batch_length])
            scores['window_end'] += list(window_end)
            full_raw_sequence = np.array(mutated_sequence_df['full_raw_sequence'][mutant_index:mutant_index+full_batch_length])
            for k, v in encoded_batch.items():
                if isinstance(v, torch.Tensor):
                    encoded_batch[k] = v.to(model.device)
            shift_labels = encoded_batch['labels'][..., 1:].contiguous()
            if (hasattr(model.config, "retrieval_aggregation_mode")) and (model.config.retrieval_aggregation_mode is not None):
                # Retrieval mode: the model returns log probabilities already fused with retrieval information.
                if reverse:
                    encoded_batch['flip'] = torch.tensor([1] * full_batch_length)
                encoded_batch['start_slice'] = window_start
                encoded_batch['end_slice'] = window_end
                encoded_batch['full_raw_sequence'] = full_raw_sequence
                fused_shift_log_probas = model(**encoded_batch, return_dict=True).fused_shift_log_probas
                loss_fct = NLLLoss(reduction='none')
                loss = - loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0], fused_shift_log_probas.shape[1])
            else:
                # Vanilla autoregressive mode: score via the shifted LM logits.
                lm_logits = model(**encoded_batch, return_dict=True).logits
                shift_logits = lm_logits[..., :-1, :].contiguous()
                loss_fct = CrossEntropyLoss(reduction='none')
                loss = - loss_fct(input=shift_logits.view(-1, shift_logits.size(-1)), target=shift_labels.view(-1)).view(shift_logits.shape[0], shift_logits.shape[1])
            # Mask out padded positions (NaN entries are excluded from the mean below).
            mask = encoded_batch['attention_mask'][..., 1:].float()
            mask[mask == 0] = float('nan')
            loss *= mask
            loss = nanmean(loss, dim=1)
            scores_batch = list(loss.cpu().numpy())
            scores['score'] += scores_batch
            mutant_index += full_batch_length
    scores = pd.DataFrame(scores)
    # Delta scores: subtract the wild type score computed over the same window.
    scores_mutated_seq = scores[scores.mutant != 'wt']
    scores_wt = scores[scores.mutant == 'wt']
    delta_scores = pd.merge(scores_mutated_seq, scores_wt, how='left', on=['window_start'], suffixes=('', '_wt'))
    delta_scores[score_var_name] = delta_scores['score'] - delta_scores['score_wt']
    delta_scores = delta_scores[['mutant', score_var_name]].groupby('mutant').mean().reset_index()
    return delta_scores
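# Illustrative usage sketch (parameter values are hypothetical; assumes `model` is a
# Tranception-style model exposing `encode_batch`, `config.tokenizer` and, optionally,
# `config.retrieval_aggregation_mode`; `sliced_df` would come from get_sequence_slices below):
# >>> sliced_df = get_sequence_slices(dms_df, target_seq=target_seq, model_context_len=1024)
# >>> scores = get_tranception_scores_mutated_sequences(model, sliced_df, batch_size_inference=16,
# ...     score_var_name="avg_score", len_target_seq=len(target_seq))
# `scores` has one row per mutant with its delta log-likelihood vs the wild type window(s).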


def get_sequence_slices(df, target_seq, model_context_len, start_idx=1, scoring_window="optimal", indel_mode=False):
    """
    Helper function that takes as input a (pandas) dataframe df that contains a list of mutant triplets (substitutions) or full mutated sequences (indels) for scoring.
    It returns a processed DMS in which sequences have been sliced to satisfy the maximum context window of the model.
    df: (dataframe) Input dataframe to be processed
    target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
    model_context_len: (int) Maximum context size for the model.
    start_idx: (int) Integer to move to 0-indexing of positions (mutation triplets are typically based on 1-indexing).
    scoring_window: (string) Method to slice sequences longer than maximum context size:
        - optimal selects a single window as large as possible via the get_optimal_window function (this is the default)
        - sliding splits the full sequence in contiguous (non-overlapping) chunks that are of size equal to the max context (except the last chunk which may be shorter)
    indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
    Note: when scoring indels for sequences that would be longer than the model max context length, it is preferable to use the "sliding" scoring_window. Use "optimal" otherwise.
    """
    len_target_seq = len(target_seq)
    num_mutants = len(df['mutant'])
    df = df.reset_index(drop=True)
    if scoring_window == "optimal":
        # Center the window on the (average) mutated position; for indels, default to the middle of the mutated sequence.
        df['mutation_barycenter'] = df['mutant'].apply(lambda x: int(np.array([int(mutation[1:-1]) - start_idx for mutation in x.split(':')]).mean())) if not indel_mode else df['mutant'].apply(lambda x: len(x)//2)
        df['scoring_optimal_window'] = df['mutation_barycenter'].apply(lambda x: get_optimal_window(x, len_target_seq, model_context_len)) if not indel_mode else df['mutant'].apply(lambda x: (0, len(x)))
        df['full_raw_sequence'] = df['mutated_sequence']
        df['mutated_sequence'] = [df['mutated_sequence'][index][df['scoring_optimal_window'][index][0]:df['scoring_optimal_window'][index][1]] for index in range(num_mutants)]
        df['window_start'] = df['scoring_optimal_window'].map(lambda x: x[0])
        df['window_end'] = df['scoring_optimal_window'].map(lambda x: x[1])
        del df['scoring_optimal_window']
        # Score the wild type over the same windows so that delta scores are comparable.
        df_wt = df.copy()
        df_wt['mutant'] = ['wt'] * num_mutants
        df_wt['full_raw_sequence'] = [target_seq] * num_mutants
        if indel_mode:
            df_wt['mutation_barycenter'] = [len_target_seq // 2] * num_mutants
            df_wt['window_end'] = df_wt['full_raw_sequence'].map(lambda x: len(x))
        df_wt['mutated_sequence'] = [target_seq[df_wt['window_start'][index]:df_wt['window_end'][index]] for index in range(num_mutants)]
        df = pd.concat([df, df_wt], axis=0)
        df = df.drop_duplicates()
    elif scoring_window == "sliding":
        # The +1 can create one extra window, which covers mutated sequences that extend past the reference length (e.g. insertions).
        num_windows = 1 + int(len_target_seq / model_context_len)
        df_list = []
        start = 0
        for window_index in range(1, num_windows + 1):
            df_sliced = df.copy()
            df_sliced['full_raw_sequence'] = df_sliced['mutated_sequence']
            df_sliced['mutated_sequence'] = df_sliced['mutated_sequence'].map(lambda x: x[start:start+model_context_len])
            df_sliced['window_start'] = [start] * num_mutants
            df_sliced['window_end'] = df_sliced['full_raw_sequence'].map(lambda x: min(len(x), start+model_context_len))
            df_sliced_wt = df_sliced.copy()
            df_sliced_wt['mutant'] = ['wt'] * num_mutants
            df_sliced_wt['full_raw_sequence'] = [target_seq] * num_mutants
            df_sliced_wt['mutated_sequence'] = df_sliced_wt['full_raw_sequence'].map(lambda x: x[start:start+model_context_len])
            df_sliced_wt['window_end'] = df_sliced_wt['full_raw_sequence'].map(lambda x: min(len(x), start+model_context_len))
            df_list.append(df_sliced)
            df_list.append(df_sliced_wt)
            start += model_context_len
        df_final = pd.concat(df_list, axis=0)
        df = df_final.drop_duplicates()
    return df.reset_index(drop=True)
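# Illustrative usage (hypothetical toy values; a real DMS dataframe needs at least a
# 'mutant' column and a 'mutated_sequence' column):
# >>> dms_df = pd.DataFrame({'mutant': ['A1G', 'C2W'],
# ...                        'mutated_sequence': ['GCDE', 'AWDE']})
# >>> sliced = get_sequence_slices(dms_df, target_seq="ACDE", model_context_len=1024)
# >>> sorted(sliced.columns)
# ['full_raw_sequence', 'mutant', 'mutated_sequence', 'mutation_barycenter', 'window_end', 'window_start']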