File size: 9,558 Bytes
2b26389 72b0e49 2b26389 9b993cf f3ed046 1a0324b 2b26389 08b9eb6 3daa625 f3ed046 1a0324b 72b0e49 2b26389 3daa625 c8e59d5 3daa625 3705c34 77b966b 3705c34 77b966b 3705c34 cdf31f1 61cedea 2b26389 2bc812b 2b26389 f3ed046 1167137 f3ed046 1167137 43b95d1 1167137 2b26389 c3846ee 2bc812b c3846ee 2bc812b c3846ee 43b95d1 2bc812b c3846ee 2b26389 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
from data.evaluate_data.utils import Ontology
import difflib
import re
# --- Models -----------------------------------------------------------------
# FAPM captioning model: built from the default config, weights restored from
# the local checkpoint, then moved to the GPU.
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')

# ESM-2 encoder (and its alphabet) used to embed the input sequence; switched
# to eval mode because it is only used for inference here.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.to('cuda')
model_esm.eval()

# --- Gene Ontology lookup tables ---------------------------------------------
godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# Two '|'-separated columns without a header row: term id and description.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
# Ids are stored with '_' in the file; rewrite them with ':'.
go_des['id'] = go_des['id'].apply(lambda go_id: go_id.replace('_', ':'))
go_obo_set = set(go_des['id'].tolist())
go_des['text'] = go_des['text'].apply(lambda text: text.lower())

# Lower-cased description -> id, and id -> lower-cased description.
GO_dict = {text: go_id for go_id, text in zip(go_des['id'], go_des['text'])}
Func_dict = {go_id: text for go_id, text in zip(go_des['id'], go_des['text'])}

# Candidate term descriptions that generated text is matched against.
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = [Func_dict[go_id] for go_id in set(terms_mf['gos'])]
choices = {term.lower(): term for term in choices_mf}
@spaces.GPU
def generate_caption(protein, prompt):
    """Predict function terms for a single protein sequence.

    The sequence is embedded with the module-level ESM model (``model_esm`` /
    ``alphabet``), the per-token embedding is padded to 1024 positions and
    passed to ``model.generate``; the generated terms are then matched against
    the known term descriptions in ``choices``.

    Args:
        protein: Amino-acid sequence as text. Surrounding whitespace is
            removed before embedding.
        prompt: Taxonomy prompt. ``None`` is replaced by the string
            ``'none'``; any other value is lower-cased.

    Returns:
        A sentence listing the matched terms, each followed by its score in
        parentheses.

    Raises:
        gr.Error: If ``protein`` is empty after stripping whitespace.
        RuntimeError: If the data loader produced no sequence to embed.
    """
    # Function-scope import: only needed to parse the generated text below.
    import ast

    protein_name = 'protein_name'
    protein_seq = (protein or '').strip()
    if not protein_seq:
        # Without this guard an empty sequence is embedded as nothing and the
        # model is run on an all-zero tensor.
        raise gr.Error("Please provide a protein sequence.")

    repr_layer = 36
    truncation_seq_length = 1024
    toks_per_batch = 4096

    print("start")
    dataset = FastaBatchedDataset([protein_name], [protein_seq])
    print("dataset prepared")
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    print("batches prepared")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )
    print("Read sequences")

    # Map the requested layer (negative indices allowed) onto the range
    # 0..num_layers before asking the encoder for it.
    num_layers = model_esm.num_layers
    assert -(num_layers + 1) <= repr_layer <= num_layers
    repr_layer = (repr_layer + num_layers + 1) % (num_layers + 1)

    esm_emb = None
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[repr_layer], return_contacts=False)
            layer_repr = out["representations"][repr_layer].to(device="cpu")
            for i in range(len(labels)):
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # Position 0 is skipped (presumably the BOS token - confirm
                # against the ESM alphabet). clone() so the slice is not a
                # view into the larger batch tensor; see
                # https://github.com/pytorch/pytorch/issues/1995
                esm_emb = layer_repr[i, 1: truncate_len + 1].clone()

    if esm_emb is None:
        raise RuntimeError("No embedding was produced for the input sequence.")
    print("esm embedding generated")

    # Pad the first dimension up to 1024 rows (transpose, right-pad the last
    # dimension, transpose back), then move to the GPU.
    pad_len = truncation_seq_length - len(esm_emb)
    esm_emb = F.pad(esm_emb.t(), (0, pad_len)).t().to('cuda')

    prompt = 'none' if prompt is None else prompt.lower()
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}

    # Generate the output
    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
                                repetition_penalty=1.0)

    # prediction[0] is parsed as '; '-separated Python literals whose first
    # two items are the term text and its score (format inferred from the
    # original parsing code - confirm against model.generate).
    # NOTE(security): the original used eval() here. The text is model output
    # influenced by the user prompt, so only literals are accepted now.
    pred_terms = []
    seen = set()
    for item in prediction[0].split('; '):
        parsed = ast.literal_eval(item)
        txt, prob = parsed[0], parsed[1]
        matches = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
        if not matches:
            continue
        t_standard = matches[0]
        if t_standard in seen:
            continue
        seen.add(t_standard)
        pred_terms.append(f'{t_standard}({prob})')

    return f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}"
# Define the FAPM interface
# Markdown text passed to gr.Markdown at the top of the demo page.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
# iface = gr.Interface(
# fn=generate_caption,
# inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
# outputs=gr.Textbox(label="Generated description"),
# description=description
# )
# # Launch the interface
# iface.launch()
# Stylesheet handed to gr.Blocks.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""

# [sequence, taxonomy prompt] pairs offered as ready-made examples.
# Original note: train index 99, 127, 266, 738, 1060 test index 4
example_inputs = [
    ["MKTLLLTLVVVTIVCLDLGNSLKCYVSREGKTQTCPEGEKLCEKYAVSYFHDGRWRYRYECTSACHRGPYNVCCSTDLCNK", 'Micrurus'],
    ["MSSSAGSGHQPSQSRAIPTRTVAISDAAQLPHDYCTTPGGTLFSTTPGGTRIIYDRKFLLDRRNSPMAQTPPCHLPNIPGVTSPGTLIEDSKVEVNNLNNLNNHDRKHAVGDDAQFEMDI", 'Homo'],
    ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", 'Sophophora'],
    ["MAARGAMLRYLRVNVNPTIQNPRECVLPFSILLRRFSEEVRGSFLDKSEVTDRVLSVVKNFQKVDPSKVTPKANFQNDLGLDSLDSVEVVMALEEEFGFEIPDNEADKIQSIDLAVDFIASHPQAK", 'Arabidopsis'],
    ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
    ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
    ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
    ['MQMYKLTAGTTGYHTLLTRTQAEHMLSLWGDKYSIDDCTPSNPIYSPSRYTKLELVYMAANATA', 'Bacteriophage'],
    ['MSITAMDAKLQRILEESTCFGIGHDPNVKECKMCDVREQCKAKTQGMNVPTPTRKKPEDVAPAKEKPTTKKTTAKKSTAKEEKKETAPKAKETKAKPKSKPKKAKAPENPNLPNFKEMSFEELVELAKERNVEWKDYNSPNITRMRLIMALKASY', 'Bacteriophage'],
    ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
]

# NOTE(review): the nesting below was reconstructed; the original indentation
# was not available. Confirm the placement of Examples and the click handler.
with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            # Left column: inputs and the submit button.
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt")
                submit_btn = gr.Button(value="Submit")
            # Right column: generated text.
            with gr.Column():
                output_text = gr.Textbox(label="Output Text")

        gr.Examples(
            examples=example_inputs,
            inputs=[input_protein, prompt],
            outputs=[output_text],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples',
        )

        # Run the prediction when the button is pressed.
        submit_btn.click(
            fn=generate_caption,
            inputs=[input_protein, prompt],
            outputs=[output_text],
        )

demo.launch(debug=True)
|