File size: 6,034 Bytes
08b9eb6
 
 
 
 
 
 
 
 
dd9c8e6
9b993cf
1a0324b
cdf31f1
3705c34
0137aa6
08b9eb6
1a0324b
 
 
 
 
 
 
 
08b9eb6
cdf31f1
 
 
 
08b9eb6
 
 
 
dd9c8e6
 
 
 
1a0324b
 
0b7981d
1a0324b
 
 
 
 
 
 
3660015
0b7981d
3660015
0b7981d
3660015
1a0324b
 
 
 
 
0b7981d
d376f39
0b7981d
 
 
c8e59d5
d376f39
 
c8e59d5
0b7981d
 
 
 
 
 
 
d376f39
0b7981d
 
 
 
 
 
 
 
 
 
 
 
 
1a0324b
0b7981d
 
 
 
1a0324b
0b7981d
 
 
 
 
 
 
 
 
3705c34
77b966b
3705c34
77b966b
3705c34
cdf31f1
61cedea
08b9eb6
61cedea
dd9c8e6
08b9eb6
 
 
1a0324b
 
 
 
 
 
08b9eb6
1a0324b
 
08b9eb6
c8e59d5
 
08b9eb6
1a0324b
08b9eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset

# from transformers import EsmTokenizer, EsmModel


# Load the model
# model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
# model.load_checkpoint("model/checkpoint_mf2.pth")
# model.to('cuda')

# model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
# model_esm.to('cuda')
# model_esm.eval()


# tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
# model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
# model_esm.to('cuda')
# model_esm.eval()

@spaces.GPU
def generate_caption(protein, prompt):
    """Generate function descriptions for a single protein sequence.

    The sequence is embedded per token with the ESM-2 model
    ``esm2_t36_3B_UR50D`` (layer 36), padded to a fixed length of 1024
    positions, and passed together with ``prompt`` to a
    ``Blip2ProteinMistral`` model loaded from ``model/checkpoint_mf2.pth``.

    Args:
        protein: Amino-acid sequence as entered by the user. All whitespace
            (spaces, tabs, newlines) is removed before embedding so that it
            does not count toward the sequence length.
        prompt: Free-text prompt forwarded to the model under the
            ``'prompt'`` key of the sample dict.

    Returns:
        The value returned by ``Blip2ProteinMistral.generate`` for this
        sample (presumably the generated captions -- confirm against the
        model implementation).

    Raises:
        gr.Error: If ``protein`` is empty or contains only whitespace.
        RuntimeError: If the data loader yields no batch, so no embedding
            could be computed.
    """
    # NOTE(review): both models are loaded from scratch on every call; the
    # module-level loading above is commented out, so this is kept as-is.
    esm_model_name = 'esm2_t36_3B_UR50D'
    repr_layer = 36
    # Used both as the ESM truncation length and as the padded length of the
    # embedding handed to the FAPM model (both were hard-coded to 1024).
    max_seq_len = 1024
    toks_per_batch = 4096
    protein_name = 'protein_name'

    # Pasted sequences often carry line breaks or trailing spaces; left in,
    # they would inflate len(sequence) and shift the per-token slice below.
    protein_seq = "".join((protein or "").split())
    if not protein_seq:
        raise gr.Error("Please provide a protein sequence.")

    print("start")
    dataset = FastaBatchedDataset([protein_name], [protein_seq])
    print("dataset prepared")
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    print("batches prepared")

    model_esm, alphabet = pretrained.load_model_and_alphabet(esm_model_name)
    model_esm.to('cuda')
    model_esm.eval()

    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(max_seq_len), batch_sampler=batches
    )
    print("Read sequences")

    esm_emb = None
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            # The model was moved to CUDA unconditionally above, so the
            # tokens must live there as well.
            toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[repr_layer], return_contacts=False)
            layer_repr = out["representations"][repr_layer].to(device="cpu")

            # The dataset holds exactly one sequence, so it is row 0 of the
            # batch. Position 0 is skipped (presumably the prepended BOS
            # token -- confirm against the ESM batch converter).
            truncate_len = min(max_seq_len, len(strs[0]))
            # Clone so the result is not a view into the larger batch tensor.
            # See https://github.com/pytorch/pytorch/issues/1995
            esm_emb = layer_repr[0, 1: truncate_len + 1].clone()

    if esm_emb is None:
        raise RuntimeError("No ESM embedding was produced for the input sequence.")
    print("esm embedding generated")

    # Zero-pad along the sequence axis (dim 0) up to max_seq_len; the
    # feature axis (dim 1) is left untouched.
    esm_emb = F.pad(esm_emb, (0, 0, 0, max_seq_len - esm_emb.size(0))).to('cuda')
    print("esm embedding processed")
    samples = {'name': [protein_name],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}

    # Drop the reference to the ESM model before the second model is loaded.
    del model_esm

    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint("model/checkpoint_mf2.pth")
    model.to('cuda')
    # Generate the output
    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
                                repetition_penalty=1.0)

    return prediction


# Define the FAPM interface
# Markdown text shown by Gradio alongside the input/output widgets.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.

The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""

# Two text inputs (sequence, prompt) are passed positionally to
# generate_caption; its return value is rendered in a single text box.
iface = gr.Interface(
    fn=generate_caption,
    inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
    outputs=gr.Textbox(label="Generated description"),
    description=description
)

# Launch the interface
# Runs at import time: this module is meant to be executed as the app entry point.
iface.launch()