Tranception_Large / utils /scoring_utils.py
PascalNotin's picture
Added Tranception model
fa87656
raw history blame
No virus
11.9 kB
import os
import tqdm
import re
import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss, NLLLoss
from torch.utils.data.sampler import Sampler, SequentialSampler
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerFast
from datasets import Dataset
# The 20 standard amino acids (one-letter codes); allowed target residues for substitutions.
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"


def get_mutated_sequence(focus_seq, mutant, start_idx=1, AA_vocab=AA_vocab):
    """
    Apply one or more substitutions to focus_seq and return the mutated sequence (substitutions only).

    focus_seq: (string) Reference sequence to mutate.
    mutant: (string) Colon-separated mutation triplets, e.g. "A12C" or "A12C:G45T": wild-type residue,
        position, new residue.
    start_idx: (int) Position (in the mutant numbering) of the first residue of focus_seq. Mutation triplets
        are typically 1-indexed, so the default of 1 converts them to 0-indexing.
    AA_vocab: (string) Allowed target residues.

    Raises:
        ValueError: if a mutation triplet cannot be parsed.
        IndexError: if a mutation position falls outside focus_seq.
        AssertionError: if the wild-type residue does not match focus_seq, or the target residue is not in AA_vocab.
    """
    mutated_seq = list(focus_seq)
    for mutation in mutant.split(":"):
        try:
            from_AA, position, to_AA = mutation[0], int(mutation[1:-1]), mutation[-1]
        except (IndexError, ValueError) as err:
            # Fail loudly: carrying on would silently reuse the previous mutation's values.
            raise ValueError("Issue with mutant: " + str(mutation)) from err
        relative_position = position - start_idx
        # Explicit bounds check: a negative index would otherwise wrap around to the end of the sequence.
        if not 0 <= relative_position < len(focus_seq):
            raise IndexError("Mutant position out of range: " + str(mutation) + " relative pos: " + str(relative_position) + " sequence length: " + str(len(focus_seq)))
        assert (from_AA == focus_seq[relative_position]), "Invalid from_AA or mutant position: " + str(mutation) + " from_AA: " + str(from_AA) + " relative pos: " + str(relative_position) + " focus_seq: " + str(focus_seq)
        assert (to_AA in AA_vocab), "Mutant to_AA is invalid: " + str(mutation)
        mutated_seq[relative_position] = to_AA
    return "".join(mutated_seq)
def nanmean(v, *args, inplace=False, **kwargs):
    """
    Mean of the tensor v that ignores NaN entries.

    Extra positional/keyword arguments (e.g. dim=1) are forwarded to Tensor.sum for both the numerator
    and the count of non-NaN entries. With inplace=True the NaN entries of v are overwritten with 0 as a
    side effect; otherwise v is left untouched. A slice made only of NaNs yields NaN (0 / 0).
    """
    nan_mask = torch.isnan(v)
    if inplace:
        v.masked_fill_(nan_mask, 0)
        cleaned = v
    else:
        cleaned = v.masked_fill(nan_mask, 0)
    valid_count = (~nan_mask).float().sum(*args, **kwargs)
    return cleaned.sum(*args, **kwargs) / valid_count
def nansum(v, *args, inplace=False, **kwargs):
    """
    Sum of the tensor v with NaN entries treated as 0.

    Extra positional/keyword arguments (e.g. dim=1) are forwarded to Tensor.sum. With inplace=True the
    NaN entries of v are overwritten with 0 as a side effect; otherwise v is left untouched.
    """
    nan_mask = torch.isnan(v)
    if inplace:
        return v.masked_fill_(nan_mask, 0).sum(*args, **kwargs)
    return torch.where(nan_mask, torch.zeros_like(v), v).sum(*args, **kwargs)
def get_optimal_window(mutation_position_relative, seq_len_wo_special, model_window):
    """
    Choose the [start, end) slice of a sequence to score so that it fits in the model context.

    - If the whole sequence fits (seq_len_wo_special <= model_window), the full range is returned.
    - If the mutation lies in the first half-window, the window is pinned to the start of the sequence.
    - If it lies in the last half-window, the window is pinned to the end of the sequence.
    - Otherwise the window is centred on the mutation, spanning model_window // 2 on each side
      (so its length is model_window - 1 when model_window is odd).
    """
    if seq_len_wo_special <= model_window:
        return [0, seq_len_wo_special]
    half = model_window // 2
    if mutation_position_relative < half:
        return [0, model_window]
    if mutation_position_relative >= seq_len_wo_special - half:
        return [seq_len_wo_special - model_window, seq_len_wo_special]
    lower = max(0, mutation_position_relative - half)
    upper = min(seq_len_wo_special, mutation_position_relative + half)
    return [lower, upper]
def sequence_replace_single(sequence, char_to_replace, char_replacements):
    """
    Replace each occurrence of the character(s) in char_to_replace with a residue drawn from char_replacements.

    sequence: (string) Sequence to process.
    char_to_replace: (string) Characters to replace, each matched literally (e.g. "X" or "BZ"). The string is
        treated as a set of characters, not as a regular expression, so regex metacharacters have no special meaning.
    char_replacements: (string) Eligible replacement characters. One is drawn uniformly, with replacement and
        independently, for each matched position using NumPy's global RNG (np.random.choice).

    Returns the new sequence; the input string is not modified.
    """
    targets = set(char_to_replace)
    positions = [idx for idx, char in enumerate(sequence) if char in targets]
    replacements = np.random.choice(a=list(char_replacements), size=len(positions), replace=True)
    chars = list(sequence)
    for position, replacement in zip(positions, replacements):
        chars[position] = replacement
    return ''.join(chars)
def sequence_replace(sequences, char_to_replace, char_replacements):
    """
    Apply sequence_replace_single to every sequence and return the results as a list, in input order.

    char_to_replace: (string) Amino acids to substitute, matched as done by sequence_replace_single.
    char_replacements: (string) Eligible replacement amino acids, sampled at random.
    """
    return list(map(lambda seq: sequence_replace_single(seq, char_to_replace, char_replacements), sequences))
def get_tranception_scores_mutated_sequences(model, mutated_sequence_df, batch_size_inference, score_var_name, len_target_seq, num_workers=10, reverse=False, indel_mode=False):
    """
    Score every row of mutated_sequence_df with the model and return, per mutant, the delta log-likelihood
    with respect to the wild type.

    Each row gets the average per-token log-probability of its sequence; positions where attention_mask is 0
    are ignored. Each mutant row is then paired with the 'wt' row that has the same window_start, and the
    differences (mutant - wt) are averaged over all windows of that mutant.

    model: model exposing encode_batch (used as the dataset transform), config.tokenizer and device. Its output
        must provide .logits, or .fused_shift_log_probas when config.retrieval_aggregation_mode is set.
    mutated_sequence_df: dataframe with at least the columns 'mutant', 'window_start', 'window_end' and
        'full_raw_sequence', including rows whose mutant is 'wt'. Any columns read by model.encode_batch must
        also be present (presumably 'mutated_sequence' - TODO confirm). Rows are matched to model outputs by
        their order in the dataframe.
    batch_size_inference: number of sequences per forward pass.
    score_var_name: name of the score column in the returned dataframe.
    len_target_seq: not used in this function body.
    num_workers: number of DataLoader worker processes.
    reverse: for models with a retrieval_aggregation_mode, also pass a 'flip' tensor of ones to the model.
    indel_mode: not used in this function body.

    Returns a dataframe with columns 'mutant' and score_var_name, one row per mutant ('wt' rows excluded).
    """
    scores = {}
    scores['mutant']=[]
    scores['window_start']=[]
    scores['window_end']=[]
    scores['score']=[]
    with torch.no_grad():
        ds = Dataset.from_pandas(mutated_sequence_df)
        ds.set_transform(model.encode_batch)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=model.config.tokenizer,
            mlm=False)
        # Keep dataset order so each batch lines up with the metadata rows sliced via mutant_index below.
        sampler = SequentialSampler(ds)
        ds_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size_inference, sampler=sampler, collate_fn=data_collator, num_workers=num_workers, pin_memory=True, drop_last=False)
        mutant_index=0
        for encoded_batch in tqdm.tqdm(ds_loader):
            full_batch_length = len(encoded_batch['input_ids'])
            # Metadata of the rows in this batch, taken from the dataframe in the same order as the loader.
            scores['mutant'] += list(mutated_sequence_df['mutant'][mutant_index:mutant_index+full_batch_length])
            window_start = np.array(mutated_sequence_df['window_start'][mutant_index:mutant_index+full_batch_length])
            scores['window_start'] += list(window_start)
            window_end = np.array(mutated_sequence_df['window_end'][mutant_index:mutant_index+full_batch_length])
            scores['window_end'] += list(window_end)
            full_raw_sequence = np.array(mutated_sequence_df['full_raw_sequence'][mutant_index:mutant_index+full_batch_length])
            # Move every tensor entry of the batch to the model's device.
            for k, v in encoded_batch.items():
                if isinstance(v, torch.Tensor):
                    encoded_batch[k] = v.to(model.device)
            # Next-token targets: position t is scored against the label at position t+1.
            shift_labels = encoded_batch['labels'][..., 1:].contiguous()
            if (hasattr(model.config,"retrieval_aggregation_mode")) and (model.config.retrieval_aggregation_mode is not None):
                # Models configured with a retrieval_aggregation_mode also receive the window bounds and the full (unsliced) sequences.
                if reverse:
                    # NOTE(review): created after the device transfer above, so this tensor stays on CPU - confirm the model does not need it on model.device.
                    encoded_batch['flip']=torch.tensor([1]*full_batch_length)
                encoded_batch['start_slice']=window_start
                encoded_batch['end_slice']=window_end
                encoded_batch['full_raw_sequence'] = full_raw_sequence #only mutated_sequence is flipped in the scoring_mirror branch of score_mutants; full_raw_sequence does not need flipping for MSA re-aligning
                fused_shift_log_probas=model(**encoded_batch,return_dict=True).fused_shift_log_probas
                # Per-token log-probability of each label: negated NLL on the model's log-probabilities.
                loss_fct = NLLLoss(reduction='none')
                loss = - loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
            else:
                lm_logits=model(**encoded_batch,return_dict=True).logits
                # Drop the last position: it has no next token to predict.
                shift_logits = lm_logits[..., :-1, :].contiguous()
                # Per-token log-probability of each label: negated cross-entropy on the shifted logits.
                loss_fct = CrossEntropyLoss(reduction='none')
                loss = - loss_fct(input=shift_logits.view(-1, shift_logits.size(-1)), target=shift_labels.view(-1)).view(shift_logits.shape[0],shift_logits.shape[1])
            # Turn masked-out positions into NaN so that nanmean averages over real tokens only.
            mask = encoded_batch['attention_mask'][..., 1:].float()
            mask[mask==0]=float('nan')
            loss *= mask
            # One average log-probability per row of the batch.
            loss = nanmean(loss, dim=1)
            scores_batch = list(loss.cpu().numpy())
            # Same value as computed at the top of the loop (the device transfer does not change the length).
            full_batch_length = len(encoded_batch['input_ids'])
            scores['score'] += scores_batch
            mutant_index+=full_batch_length
    scores = pd.DataFrame(scores)
    scores_mutated_seq = scores[scores.mutant != 'wt']
    scores_wt = scores[scores.mutant == 'wt']
    # Pair each mutant window with the wild-type score of the same window_start (columns suffixed '_wt').
    delta_scores = pd.merge(scores_mutated_seq,scores_wt,how='left',on=['window_start'],suffixes=('','_wt'))
    delta_scores[score_var_name] = delta_scores['score'] - delta_scores['score_wt']
    # Average the per-window deltas of each mutant (e.g. over several windows when sequences were sliced in chunks).
    delta_scores=delta_scores[['mutant',score_var_name]].groupby('mutant').mean().reset_index()
    return delta_scores
def get_sequence_slices(df, target_seq, model_context_len, start_idx=1, scoring_window="optimal", indel_mode=False):
    """
    Slice the sequences of a DMS dataframe so that they fit the model's maximum context window.

    df takes mutant triplets (substitutions) or full mutated sequences (indels) for scoring. The returned
    dataframe contains, for every window, the sliced mutant rows plus matching wild-type rows (mutant == 'wt').
    It adds the columns 'full_raw_sequence' (unsliced sequence), 'window_start' and 'window_end', and has a
    fresh 0..n-1 index. The input dataframe is not modified.

    df: (dataframe) Input dataframe to be processed; needs the columns 'mutant' and 'mutated_sequence'.
    target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
    model_context_len: (int) Maximum context size for the model.
    start_idx: (int) Integer to move to 0-indexing of positions (mutation triplets are typically based on 1-indexing).
    scoring_window: (string) Method to slice sequences longer than the maximum context size:
        - "optimal" selects a single window as large as possible via the get_optimal_window function (this is the default).
        - "sliding" splits the full sequence into contiguous (non-overlapping) chunks of size equal to the max context
          (except the last chunk, which may be shorter).
    indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
    Note: when scoring indels for sequences that would be longer than the model max context length, it is preferable
    to use the "sliding" scoring_window. Use "optimal" otherwise.

    Raises:
        ValueError: if scoring_window is neither "optimal" nor "sliding".
    """
    df = df.reset_index(drop=True)
    if scoring_window == "optimal":
        df = _slice_optimal_window(df, target_seq, model_context_len, start_idx, indel_mode)
    elif scoring_window == "sliding":
        df = _slice_sliding_windows(df, target_seq, model_context_len)
    else:
        # Without this check the frame would come back without window columns and scoring would fail later.
        raise ValueError("Unknown scoring_window: " + str(scoring_window) + " (expected 'optimal' or 'sliding')")
    return df.reset_index(drop=True)


def _slice_optimal_window(df, target_seq, model_context_len, start_idx, indel_mode):
    """Cut each sequence to one window (get_optimal_window, or the full mutant for indels) and append wild-type rows."""
    len_target_seq = len(target_seq)
    num_mutants = len(df['mutant'])
    if indel_mode:
        # For indels the 'mutant' column holds the full mutated sequence.
        df['mutation_barycenter'] = df['mutant'].apply(lambda x: len(x) // 2)
        df['scoring_optimal_window'] = df['mutant'].apply(lambda x: (0, len(x)))
    else:
        # Mean (0-indexed) position of all substitutions in the mutant.
        df['mutation_barycenter'] = df['mutant'].apply(lambda x: int(np.array([int(mutation[1:-1]) - start_idx for mutation in x.split(':')]).mean()))
        df['scoring_optimal_window'] = df['mutation_barycenter'].apply(lambda x: get_optimal_window(x, len_target_seq, model_context_len))
    df['full_raw_sequence'] = df['mutated_sequence']
    df['mutated_sequence'] = [sequence[window[0]:window[1]] for sequence, window in zip(df['mutated_sequence'], df['scoring_optimal_window'])]
    df['window_start'] = df['scoring_optimal_window'].map(lambda x: x[0])
    df['window_end'] = df['scoring_optimal_window'].map(lambda x: x[1])
    # Window objects are unhashable and would break drop_duplicates below.
    del df['scoring_optimal_window']
    df_wt = df.copy()
    df_wt['mutant'] = ['wt'] * num_mutants
    df_wt['full_raw_sequence'] = [target_seq] * num_mutants
    if indel_mode:
        # For indels, the wild-type reference is always the same full-length sequence; this assumes it is
        # shorter than the model context size (otherwise use "sliding").
        df_wt['mutation_barycenter'] = [len_target_seq // 2] * num_mutants
        df_wt['window_end'] = df_wt['full_raw_sequence'].map(len)
    df_wt['mutated_sequence'] = [target_seq[start:end] for start, end in zip(df_wt['window_start'], df_wt['window_end'])]
    return pd.concat([df, df_wt], axis=0).drop_duplicates()


def _slice_sliding_windows(df, target_seq, model_context_len):
    """Cut every sequence into consecutive model_context_len chunks, pairing each chunk with a wild-type row."""
    len_target_seq = len(target_seq)
    num_mutants = len(df['mutant'])
    # One more window than floor(len / context). When len(target_seq) is an exact multiple of the context,
    # the last window starts at len(target_seq): empty for the wild type, but it can hold the tail of longer
    # (insertion) mutants.
    num_windows = 1 + len_target_seq // model_context_len
    df_list = []
    for window_index in range(num_windows):
        start = window_index * model_context_len
        end = start + model_context_len
        df_sliced = df.copy()
        df_sliced['full_raw_sequence'] = df_sliced['mutated_sequence']
        df_sliced['mutated_sequence'] = df_sliced['mutated_sequence'].map(lambda x: x[start:end])
        df_sliced['window_start'] = [start] * num_mutants
        df_sliced['window_end'] = df_sliced['full_raw_sequence'].map(lambda x: min(len(x), end))
        df_sliced_wt = df_sliced.copy()
        df_sliced_wt['mutant'] = ['wt'] * num_mutants
        df_sliced_wt['full_raw_sequence'] = [target_seq] * num_mutants
        df_sliced_wt['mutated_sequence'] = [target_seq[start:end]] * num_mutants
        # The end index is recomputed because the wild type and the mutants may differ in full length.
        df_sliced_wt['window_end'] = [min(len_target_seq, end)] * num_mutants
        df_list.append(df_sliced)
        df_list.append(df_sliced_wt)
    return pd.concat(df_list, axis=0).drop_duplicates()