wenkai commited on
Commit
2b26389
1 Parent(s): 83df3cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -116
app.py CHANGED
@@ -1,91 +1,28 @@
1
- import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForCausalLM
3
- import spaces
4
- import torch.nn.functional as F
5
- import copy
6
  import torch
7
-
8
- import random
9
- import numpy as np
 
 
 
 
 
10
  from esm import pretrained, FastaBatchedDataset
11
 
12
-
13
- def get_model(model_id):
14
- a, b = pretrained.load_model_and_alphabet(model_id.split('/')[1])
15
- a.to('cuda').eval()
16
- return (a, b)
17
-
18
- models = {
19
- 'facebook/esm2_t36_3B_UR50D': get_model('facebook/esm2_t36_3B_UR50D'),
20
- }
21
 
22
 
23
-
24
- DESCRIPTION = "Esm2 embedding"
25
-
26
- colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
27
- 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
28
 
29
 
30
  @spaces.GPU
31
- def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
32
- model_esm, alphabet = models[model_id]
33
- protein_name = 'protein_name'
34
- protein_seq = protein
35
- include = 'per_tok'
36
- repr_layers = [36]
37
- truncation_seq_length = 1024
38
- toks_per_batch = 4096
39
- print("start")
40
- dataset = FastaBatchedDataset([protein_name], [protein_seq])
41
- print("dataset prepared")
42
- batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
43
- print("batches prepared")
44
-
45
- data_loader = torch.utils.data.DataLoader(
46
- dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
47
- )
48
- print(f"Read sequences")
49
- return_contacts = "contacts" in include
50
-
51
- assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
52
- repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
53
 
54
- with torch.no_grad():
55
- for batch_idx, (labels, strs, toks) in enumerate(data_loader):
56
- print(
57
- f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
58
- )
59
- if torch.cuda.is_available():
60
- toks = toks.to(device="cuda", non_blocking=True)
61
- out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
62
- representations = {
63
- layer: t.to(device="cpu") for layer, t in out["representations"].items()
64
- }
65
- if return_contacts:
66
- contacts = out["contacts"].to(device="cpu")
67
- for i, label in enumerate(labels):
68
- result = {"label": label}
69
- truncate_len = min(truncation_seq_length, len(strs[i]))
70
- # Call clone on tensors to ensure tensors are not views into a larger representation
71
- # See https://github.com/pytorch/pytorch/issues/1995
72
- if "per_tok" in include:
73
- result["representations"] = {
74
- layer: t[i, 1: truncate_len + 1].clone()
75
- for layer, t in representations.items()
76
- }
77
- if "mean" in include:
78
- result["mean_representations"] = {
79
- layer: t[i, 1: truncate_len + 1].mean(0).clone()
80
- for layer, t in representations.items()
81
- }
82
- if "bos" in include:
83
- result["bos_representations"] = {
84
- layer: t[i, 0].clone() for layer, t in representations.items()
85
- }
86
- if return_contacts:
87
- result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
88
- esm_emb = result['representations'][36]
89
  '''
90
  inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
91
  with torch.no_grad():
@@ -93,40 +30,36 @@ def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
93
  esm_emb = outputs.last_hidden_state.detach()[0]
94
  '''
95
  print("esm embedding generated")
96
- esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
97
- torch.save(esm_emb, 'example.pt')
98
- return gr.File.update(value="example.pt", visible=True)
99
-
100
- css = """
101
- #output {
102
- height: 500px;
103
- overflow: auto;
104
- border: 1px solid #ccc;
105
- }
106
- """
107
-
108
- with gr.Blocks(css=css) as demo:
109
- gr.Markdown(DESCRIPTION)
110
- with gr.Tab(label="Esm2 embedding generation"):
111
- with gr.Row():
112
- with gr.Column():
113
- input_protein = gr.Textbox(type="text", label="Upload sequence")
114
- model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
115
- submit_btn = gr.Button(value="Submit")
116
- with gr.Column():
117
- button = gr.Button("Export")
118
- pt = gr.File(interactive=False, visible=False)
119
- # gr.Examples(
120
- # examples=[
121
- # ["image1.jpg", 'Object Detection'],
122
- # ],
123
- # inputs=[input_img, task_prompt],
124
- # outputs=[output_text, output_img],
125
- # fn=process_image,
126
- # cache_examples=True,
127
- # label='Try examples'
128
- # )
129
-
130
- button.click(run_example, [input_protein, model_selector], pt)
131
-
132
- demo.launch(debug=True)
 
1
+ import os
 
 
 
 
2
  import torch
3
+ import torch.nn as nn
4
+ import pandas as pd
5
+ import torch.nn.functional as F
6
+ from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
7
+ from lavis.models.base_model import FAPMConfig
8
+ import spaces
9
+ import gradio as gr
10
+ from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
 
13
+ # from transformers import EsmTokenizer, EsmModel
 
 
 
 
 
 
 
 
14
 
15
 
16
+ # Load the model
17
+ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
18
+ model.load_checkpoint("model/checkpoint_mf2.pth")
19
+ model.to('cuda')
 
20
 
21
 
22
  @spaces.GPU
23
+ def generate_caption(protein, prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ esm_emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  '''
27
  inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
28
  with torch.no_grad():
 
30
  esm_emb = outputs.last_hidden_state.detach()[0]
31
  '''
32
  print("esm embedding generated")
33
+ esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
34
+ print("esm embedding processed")
35
+ samples = {'name': ['protein_name'],
36
+ 'image': torch.unsqueeze(esm_emb, dim=0),
37
+ 'text_input': ['none'],
38
+ 'prompt': [prompt]}
39
+
40
+ # Generate the output
41
+ prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
42
+ repetition_penalty=1.0)
43
+
44
+ return prediction
45
+ # return "test"
46
+
47
+
48
+ # Define the FAPM interface
49
+ description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
50
+
51
+ The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
52
+
53
+ iface = gr.Interface(
54
+ fn=generate_caption,
55
+ inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
56
+ outputs=gr.Textbox(label="Generated description"),
57
+ description=description
58
+ )
59
+
60
+ # Launch the interface
61
+ iface.launch()
62
+
63
+
64
+
65
+