"""Gradio app that exports ESM-2 per-residue embeddings for one protein sequence.

The user pastes a sequence, the selected ESM-2 model embeds it, and the
per-token representation of the model's final layer is zero-padded to a fixed
number of rows and offered as a downloadable ``example.pt`` file.
"""

# NOTE: the original import order is preserved on purpose; only the layout
# (one import per line) changed. Several imports are unused in this file but
# are kept because other tooling may rely on them.
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import spaces
import torch.nn.functional as F
import copy
import torch
import random
import numpy as np
from esm import pretrained, FastaBatchedDataset


DEFAULT_MODEL_ID = 'facebook/esm2_t36_3B_UR50D'

# Sequences are truncated to this many residues by the batch converter, and
# the exported embedding is zero-padded to exactly this many rows.
MAX_SEQ_LENGTH = 1024

# Token budget handed to FastaBatchedDataset.get_batch_indices.
TOKS_PER_BATCH = 4096

OUTPUT_PATH = 'example.pt'


def get_model(model_id):
    """Load an ESM model and its alphabet and move the model to the GPU.

    Args:
        model_id: Hub-style identifier such as ``'facebook/esm2_t36_3B_UR50D'``.
            Only the last path component is passed to
            ``pretrained.load_model_and_alphabet``, so a bare model name is
            accepted as well.

    Returns:
        A ``(model, alphabet)`` tuple; the model is on ``'cuda'`` in eval mode.
    """
    model, alphabet = pretrained.load_model_and_alphabet(model_id.split('/')[-1])
    model.to('cuda').eval()
    return (model, alphabet)


# Models are loaded once at import time so that requests do not pay the
# loading cost.
models = {
    DEFAULT_MODEL_ID: get_model(DEFAULT_MODEL_ID),
}

DESCRIPTION = "Esm2 embedding"

colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
            'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']


@spaces.GPU
def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
    """Embed a protein sequence with ESM-2 and save the result to ``example.pt``.

    Args:
        protein: Amino-acid sequence as plain text. Surrounding whitespace is
            ignored.
        model_id: Key into the module-level ``models`` dict.

    Returns:
        A Gradio update that points the output ``gr.File`` component at
        ``example.pt`` and makes it visible. The saved tensor holds the
        final-layer per-token representations (BOS token excluded),
        zero-padded along the first dimension to ``MAX_SEQ_LENGTH`` rows.

    Raises:
        gr.Error: If the sequence is empty or ``model_id`` is not loaded.
    """
    protein_seq = (protein or "").strip()
    if not protein_seq:
        raise gr.Error("Please enter a protein sequence.")
    if model_id not in models:
        raise gr.Error(f"Unknown model: {model_id!r}")

    model_esm, alphabet = models[model_id]
    # Always export the last layer of whichever model was selected (36 for
    # esm2_t36_3B_UR50D, matching the previously hard-coded value).
    repr_layer = model_esm.num_layers

    print("start")
    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    print("dataset prepared")
    batches = dataset.get_batch_indices(TOKS_PER_BATCH, extra_toks_per_seq=1)
    print("batches prepared")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(MAX_SEQ_LENGTH),
        batch_sampler=batches,
    )
    print("Read sequences")

    esm_emb = None
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)

            out = model_esm(toks, repr_layers=[repr_layer], return_contacts=False)
            layer_repr = out["representations"][repr_layer].to(device="cpu")

            for i, _label in enumerate(labels):
                truncate_len = min(MAX_SEQ_LENGTH, len(strs[i]))
                # Skip position 0 (BOS). Clone so the tensor is not a view
                # into the larger batch representation.
                # See https://github.com/pytorch/pytorch/issues/1995
                esm_emb = layer_repr[i, 1: truncate_len + 1].clone()

    if esm_emb is None:
        raise gr.Error("No embedding was produced for the given sequence.")
    print("esm embedding generated")

    # Zero-pad the row (residue) dimension up to MAX_SEQ_LENGTH; the feature
    # dimension is left untouched.
    pad_rows = MAX_SEQ_LENGTH - esm_emb.size(0)
    esm_emb = F.pad(esm_emb, (0, 0, 0, pad_rows))
    torch.save(esm_emb, OUTPUT_PATH)

    # gr.update works on both Gradio 3.x and 4.x, whereas the per-component
    # gr.File.update classmethod was removed in Gradio 4.
    return gr.update(value=OUTPUT_PATH, visible=True)


css = """
  #output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Esm2 embedding generation"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                # The default must be one of the offered choices; otherwise
                # the handler receives a model id that was never loaded.
                model_selector = gr.Dropdown(
                    choices=list(models.keys()),
                    label="Model",
                    value=DEFAULT_MODEL_ID,
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                button = gr.Button("Export")
                pt = gr.File(interactive=False, visible=False)

    # Both buttons run the same export so that "Submit" is not a dead control.
    submit_btn.click(run_example, [input_protein, model_selector], pt)
    button.click(run_example, [input_protein, model_selector], pt)

demo.launch(debug=True)