import os
import ast
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
from esm import pretrained, FastaBatchedDataset
from data.evaluate_data.utils import Ontology
import difflib
import re

# ESM-2 per-residue embeddings are truncated to, and FAPM inputs zero-padded
# to, this many residues.
MAX_SEQ_LEN = 1024
# ESM-2 layer whose per-residue representation is fed to FAPM.
ESM_REPR_LAYER = 36
# Token budget per ESM batch (each request embeds a single sequence).
TOKS_PER_BATCH = 4096

# Load the FAPM model (molecular-function checkpoint).
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')
# This app only runs inference: disable dropout and other train-time behaviour.
model.eval()

model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.to('cuda')
model_esm.eval()

godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# GO descriptions file: "id|text" rows. Ids have '_' rewritten to ':' and
# texts are lower-cased so they can be matched against generated captions.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
go_des['id'] = go_des['id'].str.replace('_', ':', regex=False)
go_obo_set = set(go_des['id'].tolist())
go_des['text'] = go_des['text'].str.lower()
GO_dict = dict(zip(go_des['text'], go_des['id']))
Func_dict = dict(zip(go_des['id'], go_des['text']))

# Candidate molecular-function term names that generated text is snapped onto.
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = [Func_dict[i] for i in set(terms_mf['gos'])]
choices = {x.lower(): x for x in choices_mf}


def _normalize_sequence(protein):
    """Turn raw textbox input into a bare, upper-case amino-acid string.

    FASTA header lines (starting with '>') and all whitespace, including line
    breaks from pasted multi-line sequences, are removed. Letters are
    upper-cased because ESM's alphabet uses upper-case residue codes;
    NOTE(review): lower-case letters would otherwise presumably be encoded as
    unknown tokens. Returns '' for None/blank input.
    """
    if not protein:
        return ''
    residue_lines = [
        line for line in protein.splitlines()
        if not line.lstrip().startswith('>')
    ]
    return ''.join(''.join(residue_lines).split()).upper()


def _esm_per_token_embedding(protein_seq):
    """Embed one sequence with ESM-2 and return its per-residue representation.

    Returns a CPU tensor holding layer ``ESM_REPR_LAYER`` outputs with the
    leading BOS position dropped, truncated to ``MAX_SEQ_LEN`` residues.
    """
    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(TOKS_PER_BATCH, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(MAX_SEQ_LEN),
        batch_sampler=batches,
    )
    with torch.no_grad():
        for _labels, strs, toks in data_loader:
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[ESM_REPR_LAYER], return_contacts=False)
            representations = out["representations"][ESM_REPR_LAYER].to(device="cpu")
            truncate_len = min(MAX_SEQ_LEN, len(strs[0]))
            # Position 0 is the BOS token. clone() so the result is not a view
            # into the larger batch tensor
            # (https://github.com/pytorch/pytorch/issues/1995).
            return representations[0, 1: truncate_len + 1].clone()
    raise RuntimeError("ESM data loader produced no batches for the input sequence")


def _parse_predictions(text):
    """Parse the model's '; '-separated ``(description, probability)`` literals.

    ``ast.literal_eval`` is used instead of ``eval``: the text is generated by
    the model (and steered by the user's prompt), so it must never be executed
    as arbitrary Python. Fragments that are not a literal sequence starting
    with a string are skipped instead of crashing the request.
    """
    parsed = []
    for fragment in text.split('; '):
        try:
            value = ast.literal_eval(fragment)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        if isinstance(value, (tuple, list)) and len(value) >= 2 and isinstance(value[0], str):
            parsed.append((value[0], value[1]))
    return parsed


def _match_go_terms(predictions):
    """Snap generated descriptions onto known MF term names, keeping first occurrences.

    Returns strings formatted as ``"<term name>(<probability>)"`` in prediction order.
    """
    pred_terms = []
    seen = set()
    for txt, prob in predictions:
        matches = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
        if matches and matches[0] not in seen:
            seen.add(matches[0])
            pred_terms.append(f'{matches[0]}({prob})')
    return pred_terms


@spaces.GPU
def generate_caption(protein, prompt):
    """Predict a molecular-function description for a protein sequence.

    Args:
        protein: Raw amino-acid sequence text from the UI. It may contain
            whitespace, line breaks or a FASTA header.
        prompt: Optional taxonomy prompt (e.g. 'Homo'). ``None`` or blank
            means "no prompt".

    Returns:
        A human-readable sentence listing the matched GO molecular-function
        terms with their probabilities, or a hint message when nothing matched.
    """
    protein_seq = _normalize_sequence(protein)
    if not protein_seq:
        return "Please provide a protein amino acid sequence."

    print("start")
    esm_emb = _esm_per_token_embedding(protein_seq)
    print("esm embedding generated")
    # Zero-pad along the residue axis so the model always sees MAX_SEQ_LEN rows.
    esm_emb = F.pad(esm_emb.t(), (0, MAX_SEQ_LEN - len(esm_emb))).t().to('cuda')

    # FAPM uses the literal 'none' for "no prompt" (same token as text_input
    # below). Gradio passes '' for an empty textbox, so treat blank as none too.
    if prompt is None or not prompt.strip():
        prompt = 'none'
    else:
        prompt = prompt.strip().lower()

    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}
    with torch.no_grad():
        prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10,
                                    temperature=1., repetition_penalty=1.0)

    pred_terms = _match_go_terms(_parse_predictions(prediction[0]))
    if not pred_terms:
        # Without a prompt the only thing to change is the sequence; with a
        # prompt, suggest dropping it.
        if prompt == 'none':
            return "No available predictions for this protein, you can try another protein sequence!"
        return "No available predictions for this protein, you can try to remove prompt!"
    return ("Based on the given amino acid sequence, the protein appears to have a primary "
            f"function of {', '.join(pred_terms)}")


# Define the FAPM interface
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.

The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""

css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Output Text")
        # train index 127, 266, 738, 1060 test index 4
        gr.Examples(
            examples=[
                ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
                ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
                ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
                ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
                ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
                ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
            ],
            inputs=[input_protein, prompt],
            outputs=[output_text],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples',
        )
    submit_btn.click(generate_caption, [input_protein, prompt], [output_text])

demo.launch(debug=True)