Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import HfApi, ModelFilter | |
import pandas as pd | |
from re import match | |
from tempfile import NamedTemporaryFile | |
import torch | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
# Query the HuggingFace Hub for fill-mask ESM checkpoints published by the
# "facebook" organisation, most recently modified first. The app cannot run
# without at least one of them.
MODELS = [
    info.modelId
    for info in HfApi().list_models(
        filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"),
        sort="lastModified",
        direction=-1,
    )
]
if not any(MODELS):
    raise RuntimeError("Error while retrieving models from HuggingFace Hub")

# Scoring strategies offered in the UI. Model.run_model picks one by prefix
# ("masked-marginals" / "wt-marginals"); the parenthesised hint is display-only.
SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]
class Model:
    """Wrapper for ESM models.

    Bundles a masked-LM checkpoint with its tokenizer and exposes operator
    shortcuts used by ``run_model``: ``self << seq`` tokenizes a string,
    ``self >> tokens`` returns the model logits and ``self[char]`` looks up a
    token ID.
    """

    def __init__(self, model_name: str = ""):
        """Load the selected model and tokenizer.

        An empty ``model_name`` builds an unloaded placeholder that only carries
        ``model_name``; the ``model``, ``batch_converter`` and ``alphabet``
        attributes are then not set.
        """
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            # token string -> token ID
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Run the model on a batch of tokens and return its logits."""
        # The tokenizer produces CPU tensors; move them next to the weights so
        # the model still works after being moved to the GPU in __init__.
        device = getattr(self.model, "device", None)
        if device is not None:
            batch_tokens = batch_tokens.to(device)
        return self.model(batch_tokens)["logits"]

    def __lshift__(self, input: str) -> torch.Tensor:
        """Convert an input string into a batch of token IDs."""
        return self.batch_converter(input, return_tensors="pt")["input_ids"]

    def __getitem__(self, key: str) -> int:
        """Get the token ID of a character (or special token such as '<mask>')."""
        return self.alphabet[key]

    def _parse_substitution(self, sub: str, seq: str) -> tuple:
        """Split a substitution such as 'A12C' into (wt, 0-based index, mt).

        Raises:
            RuntimeError: if the position lies outside ``seq``, the wild-type
                residue does not match ``seq``, or the mutant residue is not in
                the model's vocabulary.
        """
        wt, idx, mt = sub[0], int(sub[1:-1]) - 1, sub[-1]
        if not 0 <= idx < len(seq):
            raise RuntimeError(f"Substitution {sub} is outside the sequence")
        if seq[idx] != wt:
            raise RuntimeError(f"Substitution {sub} does not match wild-type residue {seq[idx]}")
        if mt not in self.alphabet:
            raise RuntimeError(f"Substitution {sub} uses an unrecognised residue")
        return wt, idx, mt

    def _wt_marginals(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Log-probabilities at every token position from one unmasked forward pass."""
        with torch.no_grad():
            return torch.log_softmax(self >> batch_tokens, dim=-1)[0]

    def _masked_marginals(self, batch_tokens: torch.Tensor, positions) -> dict:
        """Log-probabilities at each requested token position, computed with that position masked.

        Only the positions that are actually scored are masked, so this needs
        one forward pass per distinct position rather than per sequence token.
        """
        probs = {}
        with torch.no_grad():
            for pos in positions:
                masked = batch_tokens.clone()
                masked[0, pos] = self['<mask>']
                probs[pos] = torch.log_softmax(self >> masked, dim=-1)[0, pos]
        return probs

    def run_model(self, data):
        """Score every substitution in ``data.sub`` and store the scores in ``data.out``.

        ``data`` must provide ``seq`` (the wild-type sequence), ``sub`` (a
        DataFrame whose ``'0'`` column holds substitutions such as ``'A12C'``
        with 1-based positions), ``scoring_strategy`` (an entry of ``SCORING``)
        and ``out`` (a DataFrame sharing ``sub``'s index). The score of a
        substitution is ``log p(mt) - log p(wt)`` at its position; the scores
        are written to ``data.out[self.model_name]``.

        Raises:
            RuntimeError: if a substitution does not fit the sequence or the
                scoring strategy is not recognised.
        """
        # validate every substitution before any (possibly expensive) model call
        parsed = [self._parse_substitution(sub, data.seq) for sub in data.sub['0']]
        # The +1 offset skips the special token the tokenizer prepends
        # (presumably <cls> for ESM checkpoints).
        positions = sorted({1 + idx for _, idx, _ in parsed})
        batch_tokens = self << data.seq
        strategy = data.scoring_strategy
        # scoring strategies as described in the original ESM paper
        if strategy.startswith("wt-marginals"):
            token_probs = self._wt_marginals(batch_tokens)
        elif strategy.startswith("masked-marginals"):
            token_probs = self._masked_marginals(batch_tokens, positions)
        else:
            raise RuntimeError(f"Unrecognised scoring strategy: {strategy}")
        scores = [
            (token_probs[1 + idx][self[mt]] - token_probs[1 + idx][self[wt]]).item()
            for wt, idx, mt in parsed
        ]
        data.out[self.model_name] = pd.Series(scores, index=data.sub.index, dtype=float)
class Data:
    """Container for input and output data.

    Parses the user's sequence and substitutions, runs the shared ``Model`` on
    them, saves the scores as CSV and renders them as an HTML table.
    """

    # Model shared by all requests, replaced on the class only when a different
    # model is selected. Starts as an unloaded placeholder.
    model = Model()

    def parse_seq(self, src: str):
        """Normalise the input sequence into ``self.seq`` and validate it.

        Raises:
            RuntimeError: if the sequence is empty or contains characters that
                are not in the model's vocabulary.
        """
        self.seq = src.strip().upper()
        if not self.seq:
            raise RuntimeError("Empty sequence")
        # Validate the normalised sequence, not the raw input, so lowercase
        # letters and surrounding whitespace are accepted.
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def parse_sub(self, trg: str):
        """Work out the running mode from ``trg`` and build ``self.sub``.

        Modes:
            SVS: ``trg`` is a single string as long as ``self.seq``; every
                differing position becomes a substitution.
            DMS: ``trg`` lists 1-based residue numbers; each residue is
                substituted by every other standard amino acid.
            MUT: ``trg`` lists explicit substitutions such as ``A12C``.

        Raises:
            RuntimeError: on empty input, an unrecognised mode, or a residue
                number outside the sequence.
        """
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper()
        tokens = self.trg.split()
        if not tokens:
            raise RuntimeError("No substitutions given")
        if len(tokens) == 1 and len(tokens[0]) == len(self.seq):
            # a single string as long as the sequence: sequence vs sequence mode
            self.mode = 'SVS'
            for resi, (src, dst) in enumerate(zip(self.seq, self.trg), 1):
                if src != dst:
                    self.sub.append(f"{src}{resi}{dst}")
        else:
            self.trg = tokens
            # Patterns are anchored at the end too, so trailing junk such as
            # "12A" or "A12CD" is rejected here instead of crashing later.
            if all(match(r'\d+\Z', x) for x in self.trg):
                # all tokens are numbers: deep mutational scanning mode
                self.mode = 'DMS'
                # dict.fromkeys drops repeated residues while keeping input order
                for resi in dict.fromkeys(map(int, self.trg)):
                    if not 1 <= resi <= len(self.seq):
                        raise RuntimeError(f"Residue {resi} is outside the sequence")
                    src = self.seq[resi-1]
                    for dst in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
                        self.sub.append(f"{src}{resi}{dst}")
            elif all(match(r'[A-Z]\d+[A-Z]\Z', x) for x in self.trg):
                # all tokens look like X#Y: single substitution mode
                self.mode = 'MUT'
                self.sub = self.trg
            else:
                raise RuntimeError("Unrecognised running mode; wrong inputs?")
        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def __init__(self, src: str, trg: str, model_name: str, scoring_strategy: str, out_file):
        """Load the requested model if needed, then parse the inputs.

        ``out_file`` only needs a ``name`` attribute: the path the CSV output is
        written to.
        """
        if not model_name:
            raise RuntimeError("No model selected")
        cls = type(self)
        # Replace the shared model on the class (not the instance) so the loaded
        # weights are reused by later requests instead of being reloaded.
        if cls.model.model_name != model_name:
            cls.model = Model(model_name)
        # pin the model used for this request
        self.model = cls.model
        self.model_name = model_name
        self.parse_seq(src)
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_buffer = out_file.name

    def parse_output(self) -> str:
        """Arrange the scores for display, save them as CSV and return an HTML table.

        MUT results are sorted by decreasing score. DMS results get one pair of
        columns (substitution, score) per residue, with residues in increasing
        order and each residue's substitutions by decreasing score. SVS results
        keep their input order.
        """
        if self.mode == 'MUT':
            self.out = self.out.sort_values(self.model_name, ascending=False)
        elif self.mode == 'DMS':
            resi = self.out['0'].str[1:-1].astype(int)
            ordered = (self.out.assign(resi=resi)
                       .sort_values(['resi', self.model_name], ascending=[True, False]))
            # One block per residue. Grouping, rather than slicing every 19 rows,
            # keeps the layout right when a residue has a different number of
            # substitutions (e.g. a non-standard wild-type residue yields 20).
            blocks = [group.drop(columns='resi').reset_index(drop=True)
                      for _, group in ordered.groupby('resi', sort=True)]
            self.out = pd.concat(blocks, axis=1)
            self.out = self.out.set_axis(range(self.out.shape[1]), axis='columns')
        # save to the file offered for download
        self.out.round(2).to_csv(self.out_buffer, index=False)
        return (self.out.style
                .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x, na_rep='')
                .hide(axis=0)
                .hide(axis=1)
                .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                .to_html())

    def calculate(self):
        """Run the model and parse its output; returns (self, html)."""
        self.model.run_model(self)
        return self, self.parse_output()
def app(*argv):
    """Gradio click handler: score the inputs and reveal the CSV download.

    The first five positional arguments are (sequence, substitutions, model
    name, scoring strategy, output file); extra arguments are ignored.
    Returns the HTML results table and an update that shows the download
    component pointing at the output file.
    """
    sequence, substitutions, chosen_model, strategy, out_file, *_ = argv
    _, html = Data(sequence, substitutions, chosen_model, strategy, out_file).calculate()
    download = gr.File.update(value=out_file.name, visible=True)
    return html, download
# Build the UI. The temporary CSV file is the download component's initial value
# and stays hidden until a result has been computed.
with gr.Blocks() as demo, NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, open("instructions.md", "r", encoding="utf-8") as md:
    gr.Markdown(md.read())
    seq = gr.Textbox(lines=2, label="Sequence", placeholder="Sequence here...", value='MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ')
    trg = gr.Textbox(lines=1, label="Substitutions", placeholder="Substitutions here...", value="61 214 19 30 122 140")
    # Keep the second entry as the default, but fall back to the only model
    # when the Hub returns just one (MODELS[1] would raise IndexError).
    model_name = gr.Dropdown(MODELS, label="Model", value=MODELS[1] if len(MODELS) > 1 else MODELS[0])
    scoring_strategy = gr.Dropdown(SCORING, label="Scoring strategy", value=SCORING[1])
    btn = gr.Button(value="Submit")
    out = gr.HTML()
    bto = gr.File(value=out_file.name, visible=False, label="Download", file_count='single', interactive=False)
    btn.click(fn=app, inputs=[seq, trg, model_name, scoring_strategy, bto], outputs=[out, bto])
# demo.launch(share=True, server_name="0.0.0.0", server_port=7878)