wenkai committed on
Commit
72b0e49
1 Parent(s): d3edc5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -70
app.py CHANGED
@@ -1,44 +1,41 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import pandas as pd
5
- import torch.nn.functional as F
6
- from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
7
- from lavis.models.base_model import FAPMConfig
8
- import spaces
9
  import gradio as gr
10
- from esm_scripts.extract import run_demo
 
 
 
 
 
 
 
 
 
 
 
 
11
  from esm import pretrained, FastaBatchedDataset
12
 
13
- # from transformers import EsmTokenizer, EsmModel
14
 
 
 
 
15
 
16
- # Load the model
17
- model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
18
- model.load_checkpoint("model/checkpoint_mf2.pth")
19
- # model.to('cuda')
 
 
20
 
21
- model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
22
- # model_esm.to('cuda')
23
- model_esm.eval()
24
 
 
25
 
26
- # tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
27
- # model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
28
- # model_esm.to('cuda')
29
- # model_esm.eval()
30
 
31
- @spaces.GPU
32
- def generate_caption(protein, prompt):
33
- # Process the image and the prompt
34
- # with open('/home/user/app/example.fasta', 'w') as f:
35
- # f.write('>{}\n'.format("protein_name"))
36
- # f.write('{}\n'.format(protein.strip()))
37
- # os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
38
- # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
39
- # model=model_esm, alphabet=alphabet,
40
- # include='per_tok', repr_layers=[36], truncation_seq_length=1024)
41
 
 
 
 
42
  protein_name = 'protein_name'
43
  protein_seq = protein
44
  include = 'per_tok'
@@ -51,8 +48,6 @@ def generate_caption(protein, prompt):
51
  batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
52
  print("batches prepared")
53
 
54
- model_esm.to('cuda')
55
-
56
  data_loader = torch.utils.data.DataLoader(
57
  dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
58
  )
@@ -70,7 +65,6 @@ def generate_caption(protein, prompt):
70
  if torch.cuda.is_available():
71
  toks = toks.to(device="cuda", non_blocking=True)
72
  out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
73
- logits = out["logits"].to(device="cpu")
74
  representations = {
75
  layer: t.to(device="cpu") for layer, t in out["representations"].items()
76
  }
@@ -105,39 +99,40 @@ def generate_caption(protein, prompt):
105
  esm_emb = outputs.last_hidden_state.detach()[0]
106
  '''
107
  print("esm embedding generated")
108
- esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
109
- print("esm embedding processed")
110
- samples = {'name': ['protein_name'],
111
- 'image': torch.unsqueeze(esm_emb, dim=0),
112
- 'text_input': ['none'],
113
- 'prompt': [prompt]}
114
-
115
- del model_esm
116
-
117
- model.to('cuda')
118
- # Generate the output
119
- prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
120
- repetition_penalty=1.0)
121
-
122
- return prediction
123
- # return "test"
124
-
125
-
126
- # Define the FAPM interface
127
- description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
128
-
129
- The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
130
-
131
- iface = gr.Interface(
132
- fn=generate_caption,
133
- inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
134
- outputs=gr.Textbox(label="Generated description"),
135
- description=description
136
- )
137
-
138
- # Launch the interface
139
- iface.launch()
140
-
141
-
142
-
143
-
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ import spaces
4
+ import torch.nn.functional as F
5
+ import requests
6
+ import copy
7
+ import torch
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import io
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+
13
+ import random
14
+ import numpy as np
15
  from esm import pretrained, FastaBatchedDataset
16
 
 
17
 
18
def _load_esm(model_name, device="cuda"):
    """Load an ESM model together with its alphabet, ready for inference.

    ``pretrained.load_model_and_alphabet`` returns a ``(model, alphabet)``
    tuple, so ``.to()`` and ``.eval()`` have to be applied to the model
    element only; calling them on the tuple itself raises ``AttributeError``
    as soon as the module is imported.

    Args:
        model_name: Name of the pretrained ESM checkpoint to load.
        device: Device the model is moved to.

    Returns:
        A ``(model, alphabet)`` pair with the model on *device* in eval mode.
    """
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    return model.to(device).eval(), alphabet


# Maps the model id offered in the UI to its ``(model, alphabet)`` pair,
# which is the shape ``run_example`` unpacks.
models = {
    'facebook/esm2_t36_3B_UR50D': _load_esm('esm2_t36_3B_UR50D'),
}
21
 
22
# Checkpoints whose processors are instantiated when the module is imported.
# NOTE(review): nothing in the visible part of this file reads ``processors``;
# confirm whether these loads are still needed.
_PROCESSOR_IDS = (
    'microsoft/Florence-2-large-ft',
    'microsoft/Florence-2-large',
    'microsoft/Florence-2-base-ft',
    'microsoft/Florence-2-base',
)

# Keyed by checkpoint id, in the order listed above.
processors = {
    checkpoint: AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    for checkpoint in _PROCESSOR_IDS
}
28
 
 
 
 
29
 
30
+ DESCRIPTION = "Esm2 embedding"
31
 
32
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
33
+ 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
 
 
34
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ @spaces.GPU
37
+ def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
38
+ model_esm, alphabet = models[model_id]
39
  protein_name = 'protein_name'
40
  protein_seq = protein
41
  include = 'per_tok'
 
48
  batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
49
  print("batches prepared")
50
 
 
 
51
  data_loader = torch.utils.data.DataLoader(
52
  dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
53
  )
 
65
  if torch.cuda.is_available():
66
  toks = toks.to(device="cuda", non_blocking=True)
67
  out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
 
68
  representations = {
69
  layer: t.to(device="cpu") for layer, t in out["representations"].items()
70
  }
 
99
  esm_emb = outputs.last_hidden_state.detach()[0]
100
  '''
101
  print("esm embedding generated")
102
+ esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
103
+ torch.save(esm_emb, 'example.pt')
104
+ return gr.File.update(value="example.pt", visible=True)
105
+
106
+ css = """
107
+ #output {
108
+ height: 500px;
109
+ overflow: auto;
110
+ border: 1px solid #ccc;
111
+ }
112
+ """
113
+
114
# Ids the user can pick from; these are exactly the keys ``run_example``
# looks up in ``models``.
model_ids = list(models)

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Esm2 embedding generation"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                # The default must be one of the offered choices: an id that
                # is not a key of ``models`` would fail the lookup in
                # ``run_example``.
                model_selector = gr.Dropdown(
                    choices=model_ids,
                    label="Model",
                    value=model_ids[0] if model_ids else None,
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                button = gr.Button("Export")
                # Hidden until ``run_example`` returns the saved embedding file.
                pt = gr.File(interactive=False, visible=False)

        # Both buttons run the same handler; previously "Submit" was created
        # but never connected to anything.
        submit_btn.click(run_example, [input_protein, model_selector], pt)
        button.click(run_example, [input_protein, model_selector], pt)

demo.launch(debug=True)