|
import os |
|
import torch |
|
import torch.nn as nn |
|
import pandas as pd |
|
import torch.nn.functional as F |
|
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral |
|
from lavis.models.base_model import FAPMConfig |
|
import spaces |
|
import gradio as gr |
|
from esm_scripts.extract import run_demo |
|
from esm import pretrained, FastaBatchedDataset |
|
|
|
|
|
# Build the FAPM captioning model, load the local checkpoint and move it to
# the GPU.
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')
# Inference only: switch to eval mode, exactly as done for the ESM model below.
# nn.Module instances start in training mode, so without this any dropout
# layers that were not already disabled would stay active during generation.
model.eval()

# ESM-2 model whose per-residue representations generate_caption feeds to the
# FAPM model; ``alphabet`` provides the matching tokenizer/batch converter.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.to('cuda')
model_esm.eval()
|
|
|
# Layer whose per-residue ESM representations are fed to FAPM (same as the
# original demo).
_ESM_REPR_LAYER = 36
# Maximum number of residues embedded. The ESM batch converter truncates longer
# sequences, and the embedding is zero-padded up to this many rows.
_MAX_RESIDUES = 1024


def _clean_sequence(text):
    """Turn raw textbox input into a bare, upper-case residue string.

    Lines starting with ``>`` (FASTA headers) and all whitespace, including
    line breaks, are removed. Whitespace matters because the residue slice in
    ``_esm_per_residue_embedding`` is sized by the string length, while the
    ESM tokenizer drops whitespace. Stray newlines would therefore make the
    slice run past the last residue.

    Raises:
        gr.Error: if more than one FASTA header is present.
    """
    lines = (text or "").splitlines()
    is_header = [line.lstrip().startswith(">") for line in lines]
    if sum(is_header) > 1:
        raise gr.Error("Please enter a single protein sequence.")
    body = "".join(line for line, header in zip(lines, is_header) if not header)
    return "".join(body.split()).upper()


def _esm_per_residue_embedding(protein_seq):
    """Return the per-residue ESM representations of ``protein_seq``.

    Uses the module-level ``model_esm`` and ``alphabet``. The result is a 2-D
    tensor with one row per residue (at most ``_MAX_RESIDUES`` rows). It stays
    on the device the ESM model produced it on.
    """
    # A single sequence always forms a single batch, so the batch converter is
    # called directly instead of going through a Dataset/DataLoader.
    batch_converter = alphabet.get_batch_converter(_MAX_RESIDUES)
    _, strs, toks = batch_converter([("protein_name", protein_seq)])
    toks = toks.to(device="cuda", non_blocking=True)
    with torch.no_grad():
        out = model_esm(toks, repr_layers=[_ESM_REPR_LAYER], return_contacts=False)
    reps = out["representations"][_ESM_REPR_LAYER]
    n_residues = min(_MAX_RESIDUES, len(strs[0]))
    # Position 0 holds the BOS token, so residues occupy positions 1..n_residues.
    return reps[0, 1:n_residues + 1]


@spaces.GPU
def generate_caption(protein, prompt):
    """Generate function descriptions for one protein sequence.

    The sequence is embedded with the ESM model (``model_esm``). Its
    per-residue representations are zero-padded to ``_MAX_RESIDUES`` rows and
    passed, together with ``prompt``, to the FAPM model's ``generate`` method.

    Args:
        protein: Sequence text from the UI. FASTA header lines and whitespace
            are removed and the sequence is upper-cased. Residues beyond
            ``_MAX_RESIDUES`` are truncated.
        prompt: Free-text prompt forwarded to ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns for the single sample.

    Raises:
        gr.Error: if the input is empty, holds several FASTA records, or
            contains characters the ESM alphabet does not know.
    """
    protein_seq = _clean_sequence(protein)
    if not protein_seq:
        raise gr.Error("Please enter a protein sequence.")
    unknown = sorted(set(protein_seq) - set(alphabet.tok_to_idx))
    if unknown:
        raise gr.Error(f"Unsupported residue character(s): {', '.join(unknown)}")

    esm_emb = _esm_per_residue_embedding(protein_seq)
    # Zero-pad the residue axis (dim 0) up to the fixed length the FAPM model
    # receives. The embedding never exceeds _MAX_RESIDUES rows, so pad >= 0.
    esm_emb = F.pad(esm_emb, (0, 0, 0, _MAX_RESIDUES - esm_emb.size(0))).to('cuda')

    samples = {
        'name': ['protein_name'],
        'image': esm_emb.unsqueeze(0),  # add a batch dimension of 1
        'text_input': ['none'],
        'prompt': [prompt],
    }
    return model.generate(
        samples,
        length_penalty=0.,
        num_beams=15,
        num_captions=10,
        temperature=1.,
        repetition_penalty=1.0,
    )
|
|
|
|
|
|
|
# Markdown shown above the demo interface.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.

The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""

# Two text inputs (sequence, prompt) -> generate_caption -> one text output.
iface = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Textbox(type="text", label="Upload sequence"),
        gr.Textbox(type="text", label="Prompt"),
    ],
    outputs=gr.Textbox(label="Generated description"),
    description=description,
)

iface.launch()
|
|
|
|
|
|
|
|
|
|