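# Gradio demo for FAPM protein function prediction.
# An ESM-2 per-token embedding is computed for the input protein sequence and
# passed to Blip2ProteinMistral, which generates a natural-language function description.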
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
# Load the model
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')
# model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
# model_esm.to('cuda')
# model_esm.eval()
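# NOTE: instead of loading a standalone ESM-2 model (commented out above), the
# ESM encoder bundled with the FAPM checkpoint (model.model_esm / model.alphabet)
# is used directly inside generate_caption below.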
@spaces.GPU
def generate_caption(protein, prompt):
    # Process the protein sequence and the prompt
    # with open('/home/user/app/example.fasta', 'w') as f:
    #     f.write('>{}\n'.format("protein_name"))
    #     f.write('{}\n'.format(protein.strip()))
    # os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
    # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
    #                    model=model_esm, alphabet=alphabet,
    #                    include='per_tok', repr_layers=[36], truncation_seq_length=1024)
    protein_name = 'protein_name'
    protein_seq = protein
    include = 'per_tok'
    repr_layers = [36]
    truncation_seq_length = 1024
    toks_per_batch = 4096
    print("start")
    dataset = FastaBatchedDataset([protein_name], [protein_seq])
    print("dataset prepared")
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    print("batches prepared")
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=model.alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )
    print("Read sequences")
    return_contacts = "contacts" in include
    assert all(-(model.model_esm.num_layers + 1) <= i <= model.model_esm.num_layers for i in repr_layers)
    repr_layers = [(i + model.model_esm.num_layers + 1) % (model.model_esm.num_layers + 1) for i in repr_layers]
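    # Per-token embedding extraction; this loop follows the logic of esm_scripts/extract.py,
    # using the absolute layer indices computed above.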
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model.model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
            logits = out["logits"].to(device="cpu")
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")
            for i, label in enumerate(labels):
                result = {"label": label}
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                if "per_tok" in include:
                    result["representations"] = {
                        layer: t[i, 1 : truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in include:
                    result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[i, :truncate_len, :truncate_len].clone()
    esm_emb = result['representations'][36]
    print("esm embedding generated")
    esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
    print("esm embedding processed")
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}
    # Generate the output
    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
    return prediction
    # return "test"
# Define the FAPM interface
description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
iface = gr.Interface(
    fn=generate_caption,
    inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
    outputs=gr.Textbox(label="Generated description"),
    description=description
)
# Launch the interface
iface.launch()