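# Source: FAPM_demo / app.py (Hugging Face Space by wenkai)
# Commit: "Update app.py" (cdf31f1, verified), file size 5.7 kB
# NOTE: the lines above are page metadata captured with the file, not program code.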
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
# from transformers import EsmTokenizer, EsmModel
# Load the FAPM captioning model and its checkpoint once at import time so
# every request served by the app reuses the same weights.
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')
# Inference only: switch to eval mode, matching what is done for the ESM
# encoder below (assumes Blip2ProteinMistral is a torch.nn.Module; a freshly
# constructed module is in training mode by default).
model.eval()

# Load the ESM-2 3B encoder and its alphabet; these produce the per-token
# embeddings that are fed to `model`.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.to('cuda')
model_esm.eval()
def _clean_sequence(protein):
    """Return the residues of *protein* as a single whitespace-free string.

    Lines starting with ``>`` (FASTA headers) are dropped and every kind of
    whitespace, including line breaks inside the sequence, is removed.
    ``None`` is treated like an empty string.
    """
    lines = (protein or "").splitlines()
    residue_lines = (ln for ln in lines if not ln.lstrip().startswith(">"))
    return "".join("".join(residue_lines).split())


def _embed_sequence(protein_seq, repr_layer=36, truncation_seq_length=1024):
    """Return the per-token ESM embedding of one sequence.

    Args:
        protein_seq: Amino-acid sequence to embed.
        repr_layer: Index of the ESM layer whose hidden states are returned.
        truncation_seq_length: Maximum number of tokens kept by the batch
            converter.

    Returns:
        A 2-D tensor with one row per kept token (at most
        ``truncation_seq_length`` rows), on the device the ESM model ran on.
    """
    # The batch converter takes a list of (label, sequence) pairs; a single
    # pair gives a batch of size one, so no dataset or loader is needed.
    batch_converter = alphabet.get_batch_converter(truncation_seq_length)
    _, strs, toks = batch_converter([("protein_name", protein_seq)])
    if torch.cuda.is_available():
        toks = toks.to(device="cuda", non_blocking=True)
    with torch.no_grad():
        out = model_esm(toks, repr_layers=[repr_layer], return_contacts=False)
    truncate_len = min(truncation_seq_length, len(strs[0]))
    # Position 0 is the BOS token, so the sequence tokens start at index 1.
    # Clone so the result is not a view into the larger batch tensor.
    return out["representations"][repr_layer][0, 1 : truncate_len + 1].clone()


@spaces.GPU
def generate_caption(protein, prompt):
    """Generate function descriptions for a protein sequence.

    Args:
        protein: Amino-acid sequence entered by the user. FASTA header lines
            and whitespace are ignored.
        prompt: Text prompt forwarded to the model unchanged.

    Returns:
        The value returned by ``model.generate`` for this single sample.

    Raises:
        gr.Error: If no sequence is left after cleaning the input.
    """
    # One limit drives both truncation and padding so they cannot diverge.
    max_len = 1024

    protein_seq = _clean_sequence(protein)
    if not protein_seq:
        # An empty sequence would otherwise yield an all-padding embedding
        # and a meaningless prediction.
        raise gr.Error("Please provide a protein sequence.")

    esm_emb = _embed_sequence(protein_seq, repr_layer=36, truncation_seq_length=max_len)
    print("esm embedding generated")

    # Zero-pad along the token axis up to max_len rows: transpose so tokens
    # are the last axis, pad on the right, then transpose back.
    esm_emb = F.pad(esm_emb.t(), (0, max_len - len(esm_emb))).t().to('cuda')
    print("esm embedding processed")

    samples = {'name': ['protein_name'],
               # The embedding goes in under the 'image' key with a leading
               # batch dimension of size one.
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}
    # Generate the output
    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
    return prediction
# Gradio front end for FAPM: two text inputs (sequence and prompt) and one
# text output, wired to generate_caption.
description = (
    "Quick demonstration of the FAPM model for protein function prediction. "
    "Upload an protein sequence to generate a function description. "
    "Modify the Prompt to provide the taxonomy information.\n"
    "The model used in this app is available at "
    "[Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) "
    "and the source code can be found on "
    "[GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."
)

# Components are created up front; their order in `inputs` matches the
# (protein, prompt) parameter order of generate_caption.
sequence_box = gr.Textbox(type="text", label="Upload sequence")
prompt_box = gr.Textbox(type="text", label="Prompt")
result_box = gr.Textbox(label="Generated description")

iface = gr.Interface(
    fn=generate_caption,
    inputs=[sequence_box, prompt_box],
    outputs=result_box,
    description=description,
)

# Start serving the interface.
iface.launch()