FAPM_demo / app.py
wenkai's picture
Update app.py
83df3cd verified
raw
history blame
No virus
4.99 kB
import copy
import os
import random
import tempfile

import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import spaces
import torch.nn.functional as F
import torch
import numpy as np
from esm import pretrained, FastaBatchedDataset
def get_model(model_id, device="cuda"):
    """Load an ESM checkpoint and switch it to inference mode on *device*.

    Args:
        model_id: Either ``"org/name"`` (e.g. ``"facebook/esm2_t36_3B_UR50D"``)
            or a bare ``"name"``. Only the part after the last ``/`` is passed
            to ``esm.pretrained.load_model_and_alphabet``.
        device: Device the model is moved to. Defaults to ``"cuda"``, matching
            the previous hard-coded behaviour.

    Returns:
        The ``(model, alphabet)`` pair from ``load_model_and_alphabet``.
    """
    # rsplit(...)[-1] also accepts ids without an org prefix; the old
    # split('/')[1] raised IndexError for those.
    model_name = model_id.rsplit("/", 1)[-1]
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    # nn.Module.to moves parameters in place; eval() disables dropout.
    model.to(device).eval()
    return (model, alphabet)
# Checkpoints offered in the UI, keyed by model id. Each one is loaded
# eagerly through get_model when this module is imported.
models = {
    checkpoint_id: get_model(checkpoint_id)
    for checkpoint_id in ("facebook/esm2_t36_3B_UR50D",)
}
# Markdown heading rendered at the top of the page.
DESCRIPTION = "Esm2 embedding"

# List of 19 colour names; nothing in this file currently reads it.
colormap = (
    "blue orange green purple brown pink gray olive cyan red "
    "lime indigo violet aqua magenta coral gold tan skyblue"
).split()
def _clean_sequence(protein):
    """Normalize textbox input into a bare, upper-case residue string.

    Lines starting with ``>`` (FASTA headers) are dropped, and all whitespace,
    including line breaks from a wrapped sequence, is removed. This keeps
    ``len(sequence)``, which sizes the embedding slice in ``run_example``,
    equal to the number of residue characters. Upper-casing assumes ESM's
    residue tokens are upper-case letters.
    """
    if not protein:
        return ""
    body = [line for line in protein.splitlines() if not line.lstrip().startswith(">")]
    return "".join("".join(body).split()).upper()


@spaces.GPU
def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
    """Embed one protein sequence with ESM and expose it as a ``.pt`` download.

    Args:
        protein: Raw text from the sequence textbox. A FASTA header line and
            whitespace or line breaks are tolerated (see ``_clean_sequence``).
        model_id: Key into ``models`` selecting the loaded ESM checkpoint.

    Returns:
        A Gradio update that points the ``gr.File`` output at the saved tensor
        and makes it visible. The tensor holds the model's final-layer
        per-token representations with the BOS row excluded. It is truncated
        to 1024 residues and zero-padded along the residue axis to 1024 rows.

    Raises:
        gr.Error: If ``model_id`` is not in ``models`` or the cleaned sequence
            is empty.
    """
    if model_id not in models:
        raise gr.Error(f"Unknown model {model_id!r}. Available: {', '.join(models)}")
    sequence = _clean_sequence(protein)
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    model_esm, alphabet = models[model_id]
    max_len = 1024  # residues kept, and row count of the saved (padded) tensor
    layer = model_esm.num_layers  # final layer; 36 for esm2_t36_3B_UR50D
    # Send tokens wherever the model's weights live instead of assuming CUDA.
    device = next(model_esm.parameters()).device

    batch_converter = alphabet.get_batch_converter(max_len)
    _, strs, toks = batch_converter([("protein_name", sequence)])
    with torch.no_grad():
        out = model_esm(toks.to(device), repr_layers=[layer], return_contacts=False)

    # strs holds the untruncated input string, so clamp to the truncation length.
    n_res = min(max_len, len(strs[0]))
    # Row 0 is the BOS token; rows 1..n_res are the residues.
    esm_emb = out["representations"][layer][0, 1:n_res + 1].to("cpu")
    print("esm embedding generated")

    # Zero-pad the residue (first) axis so every file has a fixed row count.
    esm_emb = F.pad(esm_emb, (0, 0, 0, max_len - n_res))

    # A fresh directory per request keeps concurrent users from overwriting
    # each other's download, while keeping the "example.pt" file name.
    out_path = os.path.join(tempfile.mkdtemp(), "example.pt")
    torch.save(esm_emb, out_path)
    # gr.update works on Gradio 3.x and 4.x; gr.File.update was removed in 4.0.
    return gr.update(value=out_path, visible=True)
# Stylesheet passed to gr.Blocks(css=css) below. It styles an element with id
# "output". NOTE(review): no component in this file sets elem_id="output", so
# these rules appear to have no effect — confirm before relying on them.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Build the single-tab UI: sequence input and model picker on the left, and
# the downloadable embedding file on the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Esm2 embedding generation"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                # Default to a model that is actually loaded. The previous
                # default ('microsoft/Florence-2-large') is not a key of
                # `models`, so every click failed on lookup.
                model_selector = gr.Dropdown(
                    choices=list(models.keys()),
                    label="Model",
                    value=next(iter(models)),
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                button = gr.Button("Export")
                # Hidden until run_example returns a file path.
                pt = gr.File(interactive=False, visible=False)
        # Both buttons run the same pipeline: compute the embedding and
        # expose it as a downloadable file. Previously Submit had no handler.
        submit_btn.click(run_example, [input_protein, model_selector], pt)
        button.click(run_example, [input_protein, model_selector], pt)
demo.launch(debug=True)