|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import pandas as pd
|
|
import torch.nn.functional as F
|
|
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
|
|
from lavis.models.base_model import FAPMConfig
|
|
import spaces
|
|
import gradio as gr
|
|
|
|
|
|
|
|
# Build the FAPM model once at import time so every request reuses the
# same instance instead of reloading the weights per call.
model = Blip2ProteinMistral(
    esm_size='3b',
    config=FAPMConfig(),
)

# Restore the fine-tuned weights from the checkpoint file, then move the
# model to the CUDA device.
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')
|
|
|
|
|
|
@spaces.GPU
def generate_caption(protein, prompt):
    """Generate function descriptions for a single protein sequence.

    The sequence is written to a one-record FASTA file, per-token ESM2-3B
    embeddings (layer 36) are extracted by running ``esm_scripts/extract.py``
    in a subprocess, and the embedding is padded to 1024 rows and passed to
    ``model.generate`` together with the prompt.

    Args:
        protein: Amino-acid sequence text; surrounding whitespace is ignored.
        prompt: Prompt text forwarded unchanged to the model under the
            ``'prompt'`` key.

    Returns:
        Whatever ``model.generate`` returns for the single sample
        (presumably a list of generated captions -- confirm against the
        model implementation).

    Raises:
        gr.Error: If ``protein`` is empty or only whitespace.
        subprocess.CalledProcessError: If the embedding extraction script
            exits with a non-zero status.
    """
    # Imported locally so this function brings its own dependencies into
    # scope without touching the module's import section.
    import subprocess
    import sys

    max_len = 1024  # shared by the extractor's truncation flag and the padding below
    fasta_path = 'data/fasta/example.fasta'
    emb_dir = 'data/emb_esm2_3b'

    sequence = (protein or '').strip()
    if not sequence:
        # Fail early with a readable message rather than feeding an empty
        # FASTA record to the extractor.
        raise gr.Error("Please provide a protein sequence.")

    # The FASTA directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(fasta_path), exist_ok=True)
    with open(fasta_path, 'w') as f:
        # The record name fixes the output file name read back below
        # (protein_name.pt).
        f.write('>protein_name\n')
        f.write('{}\n'.format(sequence))

    # check=True raises on failure; os.system() ignored the exit status, so a
    # failed extraction could silently reuse a stale embedding file left by
    # an earlier request.
    subprocess.run(
        [
            sys.executable, "esm_scripts/extract.py",
            "esm2_t36_3B_UR50D", fasta_path, emb_dir,
            "--repr_layers", "36",
            "--truncation_seq_length", str(max_len),
            "--include", "per_tok",
        ],
        check=True,
    )

    saved = torch.load(os.path.join(emb_dir, "protein_name.pt"))
    esm_emb = saved['representations'][36]
    # Zero-pad along the first axis up to max_len rows: transpose so that
    # axis is last, pad on the right, then transpose back.
    # NOTE(review): assumes esm_emb is 2-D (tokens, hidden) -- confirm.
    esm_emb = F.pad(esm_emb.t(), (0, max_len - len(esm_emb))).t().to('cuda')

    samples = {
        'name': ['test_protein'],
        'image': torch.unsqueeze(esm_emb, dim=0),  # add a leading batch axis of size 1
        'text_input': ['none'],
        'prompt': [prompt],
    }

    prediction = model.generate(
        samples,
        length_penalty=0.,
        num_beams=15,
        num_captions=10,
        temperature=1.,
        repetition_penalty=1.0,
    )

    return prediction
|
|
|
|
|
|
# Introductory text passed to gr.Interface as its description; it contains
# Markdown links to the model card and the source repository.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.

The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
|
|
|
|
# Web UI: two text inputs (sequence and prompt) mapped to generate_caption,
# with the generated output shown in a text box.
iface = gr.Interface(
    fn=generate_caption,
    inputs=[
        # type="pil" is an image-component setting and is not a valid Textbox
        # type; omitting it uses the default plain-text input.
        gr.Textbox(label="Upload sequence"),
        gr.Textbox(label="Prompt", value="taxonomy prompt"),
    ],
    outputs=gr.Textbox(label="Generated description"),
    description=description,
)

# Start the Gradio server.
iface.launch()
|
|
|
|
|
|
|
|
|
|
|