File size: 11,921 Bytes
fa87656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import tqdm
import re
import numpy as np
import pandas as pd

import torch
from torch.nn import CrossEntropyLoss, NLLLoss
from torch.utils.data.sampler import Sampler, SequentialSampler

from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerFast
from datasets import Dataset

# Default amino-acid alphabet (one-letter codes); used below to validate substitution targets.
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"

def get_mutated_sequence(focus_seq, mutant, start_idx=1, AA_vocab=AA_vocab):
    """
    Helper function that mutates an input sequence (focus_seq) via an input mutation triplet (substitutions only).
    Mutation triplet are typically based on 1-indexing: start_idx is used for switching to 0-indexing.

    focus_seq: (string) Sequence to mutate. It is not modified; a new string is returned.
    mutant: (string) One or more triplets <from_AA><position><to_AA> separated by ":" (e.g. "A12G:C40W").
    start_idx: (int) Position, in the triplets' numbering, of the first character of focus_seq.
    AA_vocab: (string) Characters accepted as to_AA.

    Returns the mutated sequence as a string.

    Raises:
        ValueError: a triplet cannot be parsed (empty, or non-integer position).
        IndexError: a position falls outside focus_seq once shifted by start_idx.
        AssertionError: from_AA does not match focus_seq at that position, or to_AA is not in AA_vocab.
    """
    mutated_seq = list(focus_seq)
    for mutation in mutant.split(":"):
        try:
            from_AA, position, to_AA = mutation[0], int(mutation[1:-1]), mutation[-1]
        except (ValueError, IndexError) as err:
            # Fail fast: continuing would either hit an undefined name or silently
            # re-apply the values parsed from the previous triplet.
            raise ValueError("Issue with mutant: " + str(mutation)) from err
        relative_position = position - start_idx
        # Reject negative offsets explicitly; Python would otherwise wrap them to the end of the sequence.
        if not 0 <= relative_position < len(focus_seq):
            raise IndexError("Mutant position out of range: " + str(mutation) + " relative pos: " + str(relative_position) + " focus_seq length: " + str(len(focus_seq)))
        # AssertionError is raised explicitly (rather than via `assert`) so that the checks
        # survive `python -O` while callers still see the same exception type.
        # from_AA is checked against the original focus_seq, not the partially mutated copy.
        if from_AA != focus_seq[relative_position]:
            raise AssertionError("Invalid from_AA or mutant position: "+str(mutation)+" from_AA: "+str(from_AA) + " relative pos: "+str(relative_position) + " focus_seq: "+str(focus_seq))
        if to_AA not in AA_vocab:
            raise AssertionError("Mutant to_AA is invalid: "+str(mutation))
        mutated_seq[relative_position] = to_AA
    return "".join(mutated_seq)

def nanmean(v, *args, inplace=False, **kwargs):
    """
    Mean of tensor v that ignores NaN entries.

    v: (torch.Tensor) Input tensor.
    *args, **kwargs: Forwarded to Tensor.sum for both the numerator and the count of non-NaN entries (e.g. dim=1).
    inplace: (bool) If True, NaN entries of v itself are overwritten with 0; otherwise a clone is used and v is left untouched.

    Returns the sum of the non-NaN entries divided by their count. A reduction slice made only of NaNs yields 0/0.
    """
    values = v if inplace else v.clone()
    nan_mask = torch.isnan(values)
    # Zero the NaNs so they do not contribute to the sum.
    values.masked_fill_(nan_mask, 0)
    total = values.sum(*args, **kwargs)
    valid_count = (~nan_mask).float().sum(*args, **kwargs)
    return total / valid_count

def nansum(v, *args, inplace=False, **kwargs):
    """
    Sum of tensor v that treats NaN entries as 0.

    v: (torch.Tensor) Input tensor.
    *args, **kwargs: Forwarded to Tensor.sum (e.g. dim=1).
    inplace: (bool) If True, NaN entries of v itself are overwritten with 0; otherwise a clone is used and v is left untouched.

    Returns the result of Tensor.sum on the NaN-free values.
    """
    values = v if inplace else v.clone()
    # Zero the NaNs so they do not contribute to the sum.
    values.masked_fill_(torch.isnan(values), 0)
    return values.sum(*args, **kwargs)

def get_optimal_window(mutation_position_relative, seq_len_wo_special, model_window):
    """
    Pick the [start, end] slice of a sequence to feed a model whose context holds model_window positions.

    mutation_position_relative: (int) 0-indexed position the window should be centered on.
    seq_len_wo_special: (int) Sequence length, not counting special tokens.
    model_window: (int) Maximum number of positions the model can take.

    Returns a two-element list [start, end] (end exclusive):
        - the whole sequence when it fits in model_window;
        - a window flush with the start (or end) of the sequence when the position is within half a window of that edge;
        - otherwise a window extending model_window // 2 on each side of the position
          (so for an odd model_window it spans model_window - 1 positions).
    """
    half_window = model_window // 2

    # Short sequence: no slicing needed.
    if seq_len_wo_special <= model_window:
        return [0, seq_len_wo_special]

    # Position too close to the left edge to be centered.
    if mutation_position_relative < half_window:
        return [0, model_window]

    # Position too close to the right edge to be centered.
    if mutation_position_relative >= seq_len_wo_special - half_window:
        return [seq_len_wo_special - model_window, seq_len_wo_special]

    # General case: center the window on the position.
    window_start = max(0, mutation_position_relative - half_window)
    window_end = min(seq_len_wo_special, mutation_position_relative + half_window)
    return [window_start, window_end]

def sequence_replace_single(sequence, char_to_replace, char_replacements):
    """
    Replace every occurrence of char_to_replace in sequence with a character drawn at random from char_replacements.

    sequence: (string) Input sequence.
    char_to_replace: (string) Pattern to look for. It is handed to re.finditer, so it is interpreted as a regular
        expression; only the character at the start of each match is overwritten.
    char_replacements: (string) Eligible replacement characters; one is sampled (with replacement, via
        np.random.choice) independently for each match.

    Returns the new sequence as a string.
    """
    candidates = list(char_replacements)
    match_positions = [match.start() for match in re.finditer(char_to_replace, sequence)]
    # A single draw covering all matches.
    sampled = np.random.choice(a=candidates, size=len(match_positions), replace=True)
    residues = list(sequence)
    for position, new_residue in zip(match_positions, sampled):
        residues[position] = new_residue
    return ''.join(residues)

def sequence_replace(sequences, char_to_replace, char_replacements):
    """
    Helper function that applies sequence_replace_single to each sequence in sequences and returns the results as a list.
    sequences: Iterable of sequences (strings) to process.
    char_to_replace: Pattern identifying the Amino Acids to replace (forwarded unchanged to sequence_replace_single).
    char_replacements: String of eligible Amino Acids to sample replacements from (forwarded unchanged).
    """
    replaced_sequences = []
    for sequence in sequences:
        replaced_sequences.append(sequence_replace_single(sequence, char_to_replace, char_replacements))
    return replaced_sequences

def get_tranception_scores_mutated_sequences(model, mutated_sequence_df, batch_size_inference, score_var_name, len_target_seq, num_workers=10, reverse=False, indel_mode=False):
    """
    Helper function that takes as input a set of mutated sequences (in a pandas dataframe) and returns scores for each mutation (delta log likelihood wrt wild type sequence).
    model: Model used for scoring. This function relies on model.encode_batch, model.config.tokenizer, model.device and on calling model(**batch, return_dict=True).
    mutated_sequence_df: (dataframe) One row per sequence slice to score. The columns 'mutant', 'window_start', 'window_end' and 'full_raw_sequence' are read here; wild type rows are the ones with mutant == 'wt'.
    batch_size_inference: (int) Batch size passed to the DataLoader.
    score_var_name: (string) Name of the column holding the delta score in the returned dataframe.
    len_target_seq: Not used in this function.
    num_workers: (int) num_workers value passed to the DataLoader.
    reverse: (bool) If True and the model config has a non-None retrieval_aggregation_mode, a 'flip' tensor of ones (one per sequence in the batch) is added to the model inputs.
    indel_mode: Not used in this function.
    Returns a dataframe with columns 'mutant' and score_var_name, with one row per non-'wt' mutant (scores are averaged over all windows of that mutant).
    """
    # Per-row results, accumulated batch by batch and turned into a dataframe at the end.
    scores = {} 
    scores['mutant']=[]
    scores['window_start']=[]
    scores['window_end']=[]
    scores['score']=[]
    with torch.no_grad():
        ds = Dataset.from_pandas(mutated_sequence_df)
        # Rows are encoded by model.encode_batch when they are read from the dataset.
        ds.set_transform(model.encode_batch)
        # mlm=False: the collator is used in its non-masked-LM mode; it provides the 'labels' read below.
        data_collator = DataCollatorForLanguageModeling(
                        tokenizer=model.config.tokenizer,
                        mlm=False)
        # Sequential order matters: batches are matched back to dataframe rows by position (mutant_index).
        sampler = SequentialSampler(ds)
        ds_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size_inference, sampler=sampler, collate_fn=data_collator, num_workers=num_workers, pin_memory=True, drop_last=False)
        # Position in mutated_sequence_df of the first row of the current batch.
        mutant_index=0
        for encoded_batch in tqdm.tqdm(ds_loader):
            full_batch_length = len(encoded_batch['input_ids'])
            # Fetch the metadata of the rows that make up this batch.
            scores['mutant'] += list(mutated_sequence_df['mutant'][mutant_index:mutant_index+full_batch_length])
            window_start = np.array(mutated_sequence_df['window_start'][mutant_index:mutant_index+full_batch_length])
            scores['window_start'] += list(window_start)
            window_end = np.array(mutated_sequence_df['window_end'][mutant_index:mutant_index+full_batch_length])
            scores['window_end'] += list(window_end)
            full_raw_sequence = np.array(mutated_sequence_df['full_raw_sequence'][mutant_index:mutant_index+full_batch_length])
            # Move every tensor of the batch to the model's device.
            for k, v in encoded_batch.items():
                if isinstance(v, torch.Tensor):
                    encoded_batch[k] = v.to(model.device)            
            # Drop the first label so that targets are the tokens at positions 1..T-1.
            shift_labels = encoded_batch['labels'][..., 1:].contiguous()
            if (hasattr(model.config,"retrieval_aggregation_mode")) and (model.config.retrieval_aggregation_mode is not None):
                # Retrieval-enabled model: extra inputs are passed along and the model's
                # fused_shift_log_probas output is used as-is (no shift is applied here, so it is
                # presumably already aligned with shift_labels and holds log-probabilities -- TODO confirm).
                # NOTE(review): 'flip' is created after the device move above, so it is not moved to model.device here.
                if reverse:
                    encoded_batch['flip']=torch.tensor([1]*full_batch_length)
                encoded_batch['start_slice']=window_start
                encoded_batch['end_slice']=window_end
                encoded_batch['full_raw_sequence'] = full_raw_sequence #only mutated_sequence is flipped if the scoring_mirror branch of score_mutants. No need to flip full_raw_sequence for MSA re-aligning
                fused_shift_log_probas=model(**encoded_batch,return_dict=True).fused_shift_log_probas
                loss_fct = NLLLoss(reduction='none')
                # Negated per-token NLL, reshaped back to (batch, positions).
                loss = - loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
            else:
                # Plain language model: drop the last position's logits so predictions line up with shift_labels.
                lm_logits=model(**encoded_batch,return_dict=True).logits
                shift_logits = lm_logits[..., :-1, :].contiguous()
                loss_fct = CrossEntropyLoss(reduction='none') 
                # Negated per-token cross-entropy (i.e. per-token log-likelihood), reshaped back to (batch, positions).
                loss = - loss_fct(input=shift_logits.view(-1, shift_logits.size(-1)), target=shift_labels.view(-1)).view(shift_logits.shape[0],shift_logits.shape[1])
            # Positions where attention_mask is 0 become NaN so that nanmean ignores them.
            mask = encoded_batch['attention_mask'][..., 1:].float()
            mask[mask==0]=float('nan')
            loss *= mask
            # One score per sequence: mean over its non-masked positions.
            loss =  nanmean(loss, dim=1)
            scores_batch = list(loss.cpu().numpy())
            # Recomputed here; same value as at the top of the loop body.
            full_batch_length = len(encoded_batch['input_ids'])
            scores['score'] += scores_batch
            mutant_index+=full_batch_length
    scores = pd.DataFrame(scores)
    scores_mutated_seq = scores[scores.mutant != 'wt']
    scores_wt = scores[scores.mutant == 'wt']
    # Pair each mutant window with the wild type score(s) sharing the same window_start.
    delta_scores = pd.merge(scores_mutated_seq,scores_wt,how='left',on=['window_start'],suffixes=('','_wt'))
    delta_scores[score_var_name]  = delta_scores['score'] - delta_scores['score_wt']
    # Average the deltas over all rows (windows) belonging to the same mutant.
    delta_scores=delta_scores[['mutant',score_var_name]].groupby('mutant').mean().reset_index()
    return delta_scores

def get_sequence_slices(df, target_seq, model_context_len, start_idx=1, scoring_window="optimal", indel_mode=False):
    """
    Helper function that slices the sequences of a DMS dataframe so that each slice fits the model context, and adds the matching wild type ('wt') slices.
    df: (dataframe) Input dataframe. Must contain the columns 'mutant' (mutation triplets for substitutions; in indel mode its length is used as the sequence length) and 'mutated_sequence'. The caller's dataframe is not modified.
    target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
    model_context_len: (int) Maximum context size for the model.
    start_idx: (int) Offset subtracted from triplet positions to move to 0-indexing (triplets are typically 1-indexed).
    scoring_window: (string) How sequences are sliced:
        - "optimal" (default): one window per row. For substitutions it is chosen by get_optimal_window around the mean mutated position; in indel mode it is the full sequence.
        - "sliding": consecutive non-overlapping windows of model_context_len positions (the last one may be shorter).
        Any other value returns the input rows unchanged (with a fresh index).
    indel_mode: (bool) Set to True when scoring insertions and deletions. Otherwise substitutions are assumed.
    Returns a dataframe holding the mutant rows and the 'wt' rows, with the added columns 'full_raw_sequence', 'window_start' and 'window_end' ('mutation_barycenter' as well in "optimal" mode). 'mutated_sequence' holds the sliced sequence. Exact duplicate rows are dropped.
    Note: when scoring indels for sequences that would be longer than the model max context length, it is preferable to use the "sliding" scoring_window. Use "optimal" otherwise.
    """
    target_len = len(target_seq)
    num_mutants = len(df['mutant'])
    # reset_index returns a new dataframe, so the column assignments below stay local.
    df = df.reset_index(drop=True)

    if scoring_window == "optimal":
        if indel_mode:
            # Indels: center on the middle of the sequence and keep it whole.
            df['mutation_barycenter'] = df['mutant'].apply(lambda seq: len(seq) // 2)
            df['scoring_optimal_window'] = df['mutant'].apply(lambda seq: (0, len(seq)))
        else:
            def barycenter(mutant):
                # Mean 0-indexed position over all substitutions of the mutant, truncated to int.
                positions = [int(triplet[1:-1]) - start_idx for triplet in mutant.split(':')]
                return int(np.array(positions).mean())

            df['mutation_barycenter'] = df['mutant'].apply(barycenter)
            df['scoring_optimal_window'] = df['mutation_barycenter'].apply(
                lambda center: get_optimal_window(center, target_len, model_context_len)
            )

        # Keep the unsliced sequence, then cut each sequence to its own window.
        df['full_raw_sequence'] = df['mutated_sequence']
        df['mutated_sequence'] = [
            sequence[window[0]:window[1]]
            for sequence, window in zip(df['mutated_sequence'], df['scoring_optimal_window'])
        ]
        df['window_start'] = df['scoring_optimal_window'].map(lambda window: window[0])
        df['window_end'] = df['scoring_optimal_window'].map(lambda window: window[1])
        del df['scoring_optimal_window']

        # Wild type counterpart of every row, sliced with the same window.
        wt_df = df.copy()
        wt_df['mutant'] = ['wt'] * num_mutants
        wt_df['full_raw_sequence'] = [target_seq] * num_mutants
        if indel_mode:
            # For indels the wild type reference is always the same full-length sequence.
            # This assumes its length is below the model context size (otherwise use "sliding").
            wt_df['mutation_barycenter'] = [target_len // 2] * num_mutants
            wt_df['window_end'] = wt_df['full_raw_sequence'].map(len)
        wt_df['mutated_sequence'] = [
            target_seq[lower:upper]
            for lower, upper in zip(wt_df['window_start'], wt_df['window_end'])
        ]
        df = pd.concat([df, wt_df], axis=0).drop_duplicates()

    elif scoring_window == "sliding":
        # NOTE(review): when len(target_seq) is an exact multiple of model_context_len this yields one
        # extra trailing window in which the wild type slice is empty -- confirm this is intended.
        num_windows = 1 + int(target_len / model_context_len)
        pieces = []
        for window_number in range(num_windows):
            start = window_number * model_context_len
            stop = start + model_context_len

            mutant_slice = df.copy()
            mutant_slice['full_raw_sequence'] = mutant_slice['mutated_sequence']
            mutant_slice['mutated_sequence'] = mutant_slice['mutated_sequence'].map(lambda seq: seq[start:stop])
            mutant_slice['window_start'] = [start] * num_mutants
            mutant_slice['window_end'] = mutant_slice['full_raw_sequence'].map(lambda seq: min(len(seq), stop))

            wt_slice = mutant_slice.copy()
            wt_slice['mutant'] = ['wt'] * num_mutants
            wt_slice['full_raw_sequence'] = [target_seq] * num_mutants
            wt_slice['mutated_sequence'] = wt_slice['full_raw_sequence'].map(lambda seq: seq[start:stop])
            # The end index is recomputed because the wild type and the mutated sequence may differ in length.
            wt_slice['window_end'] = wt_slice['full_raw_sequence'].map(lambda seq: min(len(seq), stop))

            pieces.extend([mutant_slice, wt_slice])
        df = pd.concat(pieces, axis=0).drop_duplicates()

    return df.reset_index(drop=True)