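"""Gradio Space for generating per-residue ESM-2 protein embeddings.

Loads facebook/esm2_t36_3B_UR50D through the fair-esm package, embeds the
pasted sequence, zero-pads it to a fixed length of 1024 residues, and exposes
the result as a downloadable .pt file.
"""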
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from esm import pretrained, FastaBatchedDataset


def get_model(model_id):
    # fair-esm expects the bare checkpoint name (e.g. 'esm2_t36_3B_UR50D'),
    # so strip the 'facebook/' prefix from the Hub-style id.
    model, alphabet = pretrained.load_model_and_alphabet(model_id.split('/')[1])
    model.to('cuda').eval()
    return model, alphabet

# Cache (model, alphabet) pairs once at startup so requests reuse the loaded weights.
models = {
    'facebook/esm2_t36_3B_UR50D': get_model('facebook/esm2_t36_3B_UR50D'),
}
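# Other fair-esm checkpoints (e.g. 'facebook/esm2_t33_650M_UR50D') could be added
# to this dict the same way (assumption: enough GPU memory for each extra model).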



DESCRIPTION = "ESM-2 embedding"



@spaces.GPU
def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
    model_esm, alphabet = models[model_id]
    protein_name = 'protein_name'
    protein_seq = protein
    include = 'per_tok'            # which outputs to keep: 'per_tok', 'mean', 'bos', 'contacts'
    repr_layers = [36]             # final layer of the 36-layer 3B model
    truncation_seq_length = 1024
    toks_per_batch = 4096
    dataset = FastaBatchedDataset([protein_name], [protein_seq])
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )
    return_contacts = "contacts" in include

    # Map negative layer indices (e.g. -1 for the last layer) onto absolute layer numbers.
    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
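            # Move outputs to CPU right away to free GPU memory between batches.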
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")
            for i, label in enumerate(labels):
                result = {"label": label}
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                if "per_tok" in include:
                    result["representations"] = {
                        layer: t[i, 1: truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in include:
                    result["mean_representations"] = {
                        layer: t[i, 1: truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
            # Keep the per-residue embedding from the requested layer.
            esm_emb = result['representations'][repr_layers[0]]
    print("esm embedding generated")
    esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
    torch.save(esm_emb, 'example.pt')
    return gr.File.update(value="example.pt", visible=True)
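
# Downstream usage sketch (illustrative, not executed by this app): reload the
# exported tensor and strip the padding back to the true sequence length:
#   emb = torch.load('example.pt')       # shape (1024, hidden_dim)
#   per_residue = emb[:len(sequence)]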

css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Esm2 embedding generation"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Protein sequence")
                model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='facebook/esm2_t36_3B_UR50D')
            with gr.Column():
                button = gr.Button("Export")
                pt = gr.File(interactive=False, visible=False)

        # "Export" runs the embedding pipeline and reveals the resulting .pt file.
        button.click(run_example, [input_protein, model_selector], pt)

demo.launch(debug=True)