|
import os |
|
import torch |
|
import torch.nn as nn |
|
import pandas as pd |
|
import torch.nn.functional as F |
|
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral |
|
from lavis.models.base_model import FAPMConfig |
|
import spaces |
|
import gradio as gr |
|
|
|
from esm import pretrained, FastaBatchedDataset |
|
from data.evaluate_data.utils import Ontology |
|
import difflib |
|
import re |
|
from transformers import MistralForCausalLM |
|
|
|
|
|
def get_model(type='Molecule Function'):
    """Build a FAPM model and load the checkpoint for one GO category.

    Args:
        type: Category name. One of ``'Molecule Function'``,
            ``'Biological Process'`` or ``'Cellular Component'``. The
            misspelt ``'Cellar Component'`` is accepted as an alias because
            existing callers pass it. The parameter name shadows the builtin
            ``type``; it is kept so keyword callers keep working.

    Returns:
        The ``Blip2ProteinMistral`` instance with the category checkpoint
        loaded, after ``model.to('cuda')``.

    Raises:
        ValueError: If ``type`` is not one of the names above.
    """
    checkpoints = {
        'Molecule Function': "model/checkpoint_mf2.pth",
        'Biological Process': "model/checkpoint_bp1.pth",
        'Cellular Component': "model/checkpoint_cc2.pth",
        'Cellar Component': "model/checkpoint_cc2.pth",
    }
    # Validate before the expensive model construction: previously an unknown
    # name fell through every branch and returned a model with no checkpoint
    # loaded and never moved to the GPU.
    if type not in checkpoints:
        raise ValueError(
            f"Unknown model type {type!r}; expected one of {sorted(checkpoints)}"
        )

    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(checkpoints[type])
    if type == 'Biological Process':
        # Only this category replaces the Q-Former BERT with the module
        # stored in a separate file.
        model.Qformer.bert = torch.load('model/bp1_bert.pth', map_location=torch.device('cpu'))
    model.to('cuda')
    return model
|
|
|
|
|
# One loaded FAPM model per GO category. Keys are the user-facing labels; the
# second element of each pair is the category string handed to get_model().
# The cellular-component entry deliberately passes 'Cellar Component', which
# is the spelling get_model() recognises.
models = {
    label: get_model(category)
    for label, category in (
        ('Molecule Function', 'Molecule Function'),
        ('Biological Process', 'Biological Process'),
        ('Cellular Component', 'Cellar Component'),
    )
}
|
|
|
|
|
# Half-precision Mistral language model on the GPU; it is passed to every
# FAPM model's generate() call in generate_caption().
mistral_model = MistralForCausalLM.from_pretrained(
    "teknium/OpenHermes-2.5-Mistral-7B",
    torch_dtype=torch.float16,
).to('cuda')
|
|
|
|
|
# ESM-2 encoder used to embed the input sequence, plus its alphabet, which
# supplies the batch converter in generate_caption(). The encoder lives on
# the GPU in inference (eval) mode.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm = model_esm.to('cuda').eval()
|
|
|
# Gene Ontology graph loaded from the OBO file.
godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# Two-column, pipe-separated table: GO id and its text description.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()

# Underscores in the ids are replaced with colons; descriptions are
# lower-cased so later lookups are case-insensitive.
go_des['id'] = [re.sub('_', ':', go_id) for go_id in go_des['id']]
go_des['text'] = [text.lower() for text in go_des['text']]

go_obo_set = set(go_des['id'])
# description -> id, and id -> description.
GO_dict = {text: go_id for text, go_id in zip(go_des['text'], go_des['id'])}
Func_dict = {go_id: text for go_id, text in zip(go_des['id'], go_des['text'])}
|
|
|
def _choices_from_terms(terms):
    """Return {lower-cased description: description} for the ids in terms['gos']."""
    descriptions = [Func_dict[go_id] for go_id in set(terms['gos'])]
    return {text.lower(): text for text in descriptions}


# Candidate GO descriptions per category; generate_caption() fuzzy-matches
# the generated text against the keys of these dicts.
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = _choices_from_terms(terms_mf)

terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
choices_bp = _choices_from_terms(terms_bp)

terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
choices_cc = _choices_from_terms(terms_cc)

# Keyed by the same category labels as `models`.
choices = dict(
    [
        ('Molecule Function', choices_mf),
        ('Biological Process', choices_bp),
        ('Cellular Component', choices_cc),
    ]
)
|
|
|
def _embed_protein(protein_seq, layer=36, max_len=1024, toks_per_batch=4096):
    """Embed one protein sequence with the module-level ESM model.

    Args:
        protein_seq: Amino-acid sequence as a string.
        layer: ESM layer whose per-token representations are extracted.
        max_len: Truncation length given to the batch converter; the result
            is zero-padded to this many rows.
        toks_per_batch: Token budget used when batching the dataset.

    Returns:
        A 2-D tensor with ``max_len`` rows (one per residue, zero rows after
        the end of the sequence), moved to the CUDA device.

    Raises:
        ValueError: If ``layer`` is outside the loaded model's layer range.
    """
    if not 0 <= layer <= model_esm.num_layers:
        raise ValueError(
            f"layer {layer} is out of range for a model with {model_esm.num_layers} layers"
        )

    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(max_len),
        batch_sampler=batches,
    )

    esm_emb = None
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[layer], return_contacts=False)
            layer_repr = out["representations"][layer].to(device="cpu")
            for i, seq in enumerate(strs):
                truncate_len = min(max_len, len(seq))
                # Position 0 is skipped; rows 1..truncate_len are kept.
                esm_emb = layer_repr[i, 1: truncate_len + 1].clone()

    # Pad rows (second-to-last dimension) at the end up to max_len.
    esm_emb = F.pad(esm_emb, (0, 0, 0, max_len - esm_emb.size(0)))
    return esm_emb.to('cuda')


def _match_terms(raw_prediction, vocabulary, cutoff=0.9):
    """Convert one model's raw output into de-duplicated ``term(prob)`` strings.

    Args:
        raw_prediction: Generated text made of ``'; '``-separated items, each
            evaluating to a sequence whose first two elements are the term
            text and its probability.
        vocabulary: Iterable of lower-cased candidate terms.
        cutoff: Minimum ``difflib`` similarity for a match.

    Returns:
        List of ``"<matched term>(<prob>)"`` strings in first-seen order.
    """
    # SECURITY NOTE(review): eval() executes text produced by the language
    # model, and that text is conditioned on the user-supplied prompt. If the
    # items are always plain tuple literals, ast.literal_eval would be a safe
    # replacement -- confirm the output format of model.generate() first.
    parsed_items = [eval(item) for item in raw_prediction.split('; ')]

    pred_terms = []
    seen = set()
    for parsed in parsed_items:
        txt = parsed[0]
        prob = parsed[1]
        matches = difflib.get_close_matches(txt.lower(), vocabulary, n=1, cutoff=cutoff)
        if not matches:
            continue
        t_standard = matches[0]
        if t_standard not in seen:
            seen.add(t_standard)
            pred_terms.append(t_standard + f'({prob})')
    return pred_terms


def _format_report(terms_by_category, prompt_given):
    """Render the matched terms of the three categories as the output text.

    Args:
        terms_by_category: Mapping from category label to its term strings.
        prompt_given: Whether the user supplied a taxonomy prompt; controls
            whether the empty-result message suggests removing it.

    Returns:
        The text shown to the user.
    """
    function_terms = terms_by_category.get('Molecule Function', [])
    process_terms = terms_by_category.get('Biological Process', [])
    component_terms = terms_by_category.get('Cellular Component', [])

    if not (function_terms or process_terms or component_terms):
        # Suggest removing the prompt only when one was actually supplied
        # (the original condition was inverted).
        if prompt_given:
            return "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
        return "No available predictions for this protein, you can use other two types of model or try another sequence!"

    def bullets(terms):
        return ''.join(f'- {term}\n' for term in terms)

    parts = []
    if function_terms:
        parts.append(
            f"Based on the given amino acid sequence, the protein appears to have a primary function of \n{bullets(function_terms)} \n"
        )
    if process_terms:
        parts.append(
            f"It is likely involved in the following process: \n{bullets(process_terms)} \n"
        )
    if component_terms:
        parts.append(
            f"It's subcellular localization is within the: \n{bullets(component_terms)}"
        )
    return ''.join(parts)


@spaces.GPU
def generate_caption(protein, prompt):
    """Predict GO terms for a protein sequence with all three FAPM models.

    Args:
        protein: Amino-acid sequence entered by the user.
        prompt: Optional taxonomy prompt. ``None`` is replaced by ``'none'``;
            any other value is lower-cased.

    Returns:
        Text listing the matched terms per category, or a message when the
        input is blank or nothing matched.
    """
    # A blank sequence would otherwise be embedded as all padding and still
    # sent through the three models.
    if protein is None or not protein.strip():
        return "Please provide a protein sequence."

    esm_emb = _embed_protein(protein)

    # NOTE(review): the bundled examples pass '' as the prompt, which reaches
    # the model as '' rather than 'none' -- confirm which value the model
    # expects for "no prompt" before normalising it here.
    prompt = 'none' if prompt is None else prompt.lower()
    prompt_given = prompt.strip() not in ('', 'none')

    samples = {
        'name': ['protein_name'],
        'image': torch.unsqueeze(esm_emb, dim=0),
        'text_input': ['none'],
        'prompt': [prompt],
    }

    # Keyed by category label so the report does not depend on dict order.
    terms_by_category = {}
    for model_id, model in models.items():
        prediction = model.generate(
            mistral_model,
            samples,
            length_penalty=0.,
            num_beams=15,
            num_captions=10,
            temperature=1.,
            repetition_penalty=1.0,
        )
        terms_by_category[model_id] = _match_terms(prediction[0], choices[model_id])

    return _format_report(terms_by_category, prompt_given)
|
|
|
|
|
|
|
# Markdown blurb rendered at the top of the demo page. Paragraphs are
# separated by a blank line.
description = "\n\n".join(
    [
        "Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.",
        "Our paper is available at [BioRxiv](https://www.biorxiv.org/content/10.1101/2024.05.07.593067v1)",
        "The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main).",
        "Thanks for the support from ProtonUnfold Tech. Co., Ltd (https://www.protonunfold.com/).",
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet passed to gr.Blocks: a fixed-height, scrollable, bordered box
# for the element with id "output".
css = "\n".join(
    [
        "",
        "#output {",
        "    height: 500px;",
        "    overflow: auto;",
        "    border: 1px solid #ccc;",
        "}",
        "",
    ]
)
|
|
|
# Example inputs offered under "Try examples": [sequence, taxonomy prompt].
example_rows = [
    ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
    ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
    ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
    ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
    ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
    ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
]

# Page layout: description on top, then a single tab with the inputs on the
# left and the prediction panel on the right.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost -- confirm gr.Examples and the click handler belong
# inside the tab.
with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                with gr.Accordion('Prediction:', open=True):
                    output_markdown = gr.Markdown(label="Output")

        # Example rows run through generate_caption, with results cached.
        gr.Examples(
            examples=example_rows,
            inputs=[input_protein, prompt],
            outputs=[output_markdown],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples',
        )
        submit_btn.click(
            generate_caption,
            [input_protein, prompt],
            [output_markdown],
        )

demo.launch(debug=True)
|
|
|
|