# PascalNotin's picture
# Fixed library bug and improved interface messages
# 5d3f7a9
# raw
# history blame
# 10.9 kB
import torch
import transformers
from transformers import PreTrainedTokenizerFast
import tranception
import datasets
from tranception import config, model_pytorch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
# Module-level tokenizer, built once from a tokenizer file addressed relative to the
# current working directory. It is attached to the loaded model's config inside
# score_and_create_matrix_all_singles.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]"
)
#######################################################################################################################################
############################################### HELPER FUNCTIONS ####################################################################
#######################################################################################################################################
# Default alphabet of one-letter amino acid codes used to enumerate substitutions.
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"


def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
    """Enumerate every single amino acid substitution inside a position range.

    Args:
        sequence: Starting protein sequence, one character per residue.
        AA_vocab: Iterable of candidate replacement residues.
        mutation_range_start: 1-indexed first position to mutate (inclusive).
            ``None`` means position 1.
        mutation_range_end: 1-indexed last position to mutate (inclusive).
            ``None`` means the last position of ``sequence``.

    Returns:
        A DataFrame with two columns: ``mutant`` (wild-type residue, 1-indexed
        position, new residue, e.g. ``"C2A"``) and ``mutated_sequence`` (the
        full sequence carrying that substitution).
    """
    mutation_range_start = 1 if mutation_range_start is None else int(mutation_range_start)
    mutation_range_end = len(sequence) if mutation_range_end is None else int(mutation_range_end)

    sequence_list = list(sequence)
    window = sequence[mutation_range_start - 1:mutation_range_end]

    mutant_names = []
    mutated_sequences = []
    # Enumerate with the absolute 1-indexed position: counting from 0 inside the
    # window would mutate (and label) the wrong residue whenever the range does
    # not start at position 1.
    for position, current_AA in enumerate(window, start=mutation_range_start):
        for mutated_AA in AA_vocab:
            if mutated_AA == current_AA:
                continue
            mutated_sequence = sequence_list.copy()
            mutated_sequence[position - 1] = mutated_AA
            mutant_names.append(current_AA + str(position) + mutated_AA)
            mutated_sequences.append("".join(mutated_sequence))

    return pd.DataFrame(
        {"mutant": mutant_names, "mutated_sequence": mutated_sequences},
        columns=["mutant", "mutated_sequence"],
    )
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
    """Draw the substitution heatmap for the scored single mutants and save it as a PNG.

    Args:
        scores: DataFrame with columns ``mutant``, ``position``, ``target_AA``
            and ``avg_score``.
        sequence: Starting protein sequence; used for wild-type residues and
            for the figure width.
        AA_vocab: Amino acid letters, one heatmap row of labels per letter.
        mutation_range_start: 1-indexed first position shown (inclusive);
            ``None`` means position 1.
        mutation_range_end: 1-indexed last position shown (inclusive);
            ``None`` means the last position of ``sequence``.

    Returns:
        The ``matplotlib.pyplot`` module, with the heatmap drawn on the current
        figure. The figure is also written to
        ``fitness_scoring_substitution_matrix.png``.
    """
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        # The end bound (not the start bound) must default to the sequence
        # length; otherwise range() below receives None.
        mutation_range_end = len(sequence)
    mutation_range_start = int(mutation_range_start)
    mutation_range_end = int(mutation_range_end)
    n_positions = mutation_range_end - mutation_range_start + 1

    piv = scores.pivot(index='position',columns='target_AA',values='avg_score').transpose().round(4)
    fig, ax = plt.subplots(figsize=(len(sequence)*1.2,20))

    # One dict lookup per cell instead of a boolean scan of the whole frame.
    score_by_mutant = dict(zip(scores.mutant, scores.avg_score))

    # Labels are laid out row-major: one row per target residue, one column per
    # position. Cells without a scored mutant (e.g. the wild-type residue) show 0.
    cell_labels = []
    for target_AA in AA_vocab:
        for position in range(mutation_range_start, mutation_range_end + 1):
            mutant = sequence[position-1]+str(position)+target_AA
            value = float(score_by_mutant.get(mutant, 0.0))
            cell_labels.append("{} \n {:.4f}".format(mutant, value))
    labels = np.asarray(cell_labels).reshape(len(AA_vocab), n_positions)

    heat = sns.heatmap(
        piv,
        annot=labels,
        fmt="",
        cmap='RdYlGn',
        linewidths=0.30,
        # Clip the colour scale to the 2nd-98th percentile of the scores.
        vmin=np.percentile(scores.avg_score,2),
        vmax=np.percentile(scores.avg_score,98),
        cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},
        ax=ax,
    )
    # The last axes of the figure is the colour bar; enlarge its label.
    heat.figure.axes[-1].yaxis.label.set_size(20)
    heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=30, pad=40)
    heat.set_xlabel("Sequence position", fontsize = 20)
    heat.set_ylabel("Amino Acid mutation", fontsize = 20)
    plt.savefig('fitness_scoring_substitution_matrix.png')
    return plt
def suggest_mutations(scores):
    """Build the text recommendation shown next to the heatmap.

    Args:
        scores: DataFrame with at least the columns ``mutant``, ``avg_score``
            and ``position``.

    Returns:
        A string listing the (up to) 5 mutants with the highest ``avg_score``
        and the (up to) 5 positions whose positive-scoring mutants have the
        highest mean ``avg_score``.
    """
    intro_message = "The following mutations may be sensible options to improve fitness: \n\n"

    # Best mutants
    top_mutants = list(scores.sort_values(by=['avg_score'],ascending=False).head(5).mutant)
    mutant_recos = "The 5 single mutants with highest predicted fitness are:\n {} \n\n".format(", ".join(top_mutants))

    # Best positions: average only the score column. Averaging the whole frame
    # fails on recent pandas because of the string columns (mutant, sequences).
    positive_scores = scores[scores.avg_score > 0]
    position_avg = positive_scores.groupby('position')['avg_score'].mean()
    top_positions = list(position_avg.sort_values(ascending=False).head(5).index.astype(str))
    position_recos = "The 5 positions with the highest average fitness increase are:\n {}".format(", ".join(top_positions))

    return intro_message+mutant_recos+position_recos
def get_mutated_protein(sequence,mutant):
    """Return ``sequence`` with the substitution described by ``mutant`` applied.

    ``mutant`` is read as: first character ignored, last character the new
    residue, everything in between the 1-indexed position (e.g. ``"C2A"``).
    """
    new_residue = mutant[-1]
    site = int(mutant[1:-1]) - 1  # convert the 1-indexed position to a list index
    residues = [*sequence]
    residues[site] = new_residue
    return ''.join(residues)
def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Small",scoring_mirror=False,batch_size_inference=20,num_workers=0,AA_vocab=AA_vocab):
    """Score all single substitutions of ``sequence`` and build the app outputs.

    Args:
        sequence: Starting protein sequence.
        mutation_range_start: 1-indexed first position to mutate (inclusive);
            ``None`` means position 1.
        mutation_range_end: 1-indexed last position to mutate (inclusive);
            ``None`` means the last position of ``sequence``.
        model_type: One of ``"Small"``, ``"Medium"`` or ``"Large"``; selects
            the pretrained checkpoint to load.
        scoring_mirror: Forwarded to ``model.score_mutants``.
        batch_size_inference: Forwarded to ``model.score_mutants``.
        num_workers: Forwarded to ``model.score_mutants``.
        AA_vocab: Amino acid letters used to enumerate substitutions.

    Returns:
        A tuple ``(score_heatmap, recommendations)``: the value returned by
        ``create_scoring_matrix_visual`` and the text from ``suggest_mutations``.

    Raises:
        ValueError: If ``model_type`` is not a known model size.
    """
    checkpoints = {
        "Small": "PascalNotin/Tranception_Small",
        "Medium": "PascalNotin/Tranception_Medium",
        "Large": "PascalNotin/Tranception_Large",
    }
    if model_type not in checkpoints:
        # Fail with a clear message rather than an unbound-variable error.
        raise ValueError(
            "Unknown model_type {!r}; expected one of {}".format(model_type, ", ".join(checkpoints))
        )

    # Resolve open-ended bounds once so every helper sees the same concrete range.
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)

    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=checkpoints[model_type])
    model.config.tokenizer = tokenizer

    all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
    scores = model.score_mutants(DMS_data=all_single_mutants,
                                 target_seq=sequence,
                                 scoring_mirror=scoring_mirror,
                                 batch_size_inference=batch_size_inference,
                                 num_workers=num_workers,
                                 indel_mode=False
                                 )
    # Re-attach the mutant names by joining on the mutated sequence, then split
    # each name into its 1-indexed position and target residue.
    scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
    scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])

    score_heatmap = create_scoring_matrix_visual(scores,sequence,AA_vocab,mutation_range_start,mutation_range_end)
    return score_heatmap,suggest_mutations(scores)
#######################################################################################################################################
############################################### GRADIO INTERFACE ####################################################################
#######################################################################################################################################
# Static text shown by the Gradio interface: page title, description shown with it,
# and an HTML article snippet linking to the paper.
title = "Interactive in silico directed evolution with Tranception"
description = "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitutions vs. the starting sequence, and outputs a fitness heatmap and recommendations to guide the selection of the mutation to apply. Note: The current version does not currently leverage homologs retrieval at inference time to boost fitness prediction performance."
article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</a></p>"
# Example inputs, one single-element list per example, each written as "NAME: SEQUENCE".
# Currently unused: the `examples=` argument of the gr.Interface call is commented out.
# NOTE(review): each string includes the "NAME: " prefix; the scoring function uses its
# input verbatim as the sequence, so confirm the prefix is stripped before enabling these.
examples=[
['A4_HUMAN: MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN'],
['ADRB2_HUMAN: MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
['AMIE_PSEAE: MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA'],
['P53_HUMAN: MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD']
]
# Input components, in the order the scoring function expects them after the sequence.
model_size_selection = gr.Radio(label="Tranception model size (larger models are more accurate but are slower at inference)", choices=["Small","Medium","Large"], value="Small")
protein_sequence_input = gr.Textbox(lines=1, label="Input protein sequence (see below for examples; default = RL40A_YEAST)",value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK")
mutation_range_start = gr.Number(label="Start of mutation range (min value = 1)",value=1,precision=0)
mutation_range_end = gr.Number(label="End of mutation range (leave empty for full length)",value=10,precision=0)
scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
# Output components.
# TODO: find a way to make the output scrollable
output_plot = gr.Plot(label="Fitness scores for all single amino acid substitutions in mutation range")
output_recommendations = gr.Textbox(label="Mutation recommendations")
# Build and start the app. Inputs are listed in the positional order of
# score_and_create_matrix_all_singles (sequence, range start, range end, model
# size, mirror flag); its two return values map to the two outputs.
gr.Interface(
    fn=score_and_create_matrix_all_singles,
    inputs=[protein_sequence_input,mutation_range_start,mutation_range_end,model_size_selection,scoring_mirror],
    # Use the labelled output components defined above instead of the generic
    # "plot"/"text" shortcuts, which discard those labels.
    outputs=[output_plot,output_recommendations],
    title=title,
    description=description,
    article=article,
    #examples=examples,
    allow_flagging="never"
).launch(debug=True)