wenkai commited on
Commit
892748f
1 Parent(s): 71ea7d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -19
app.py CHANGED
@@ -7,7 +7,7 @@ from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
7
  from lavis.models.base_model import FAPMConfig
8
  import spaces
9
  import gradio as gr
10
- from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
  from data.evaluate_data.utils import Ontology
13
  import difflib
@@ -15,9 +15,29 @@ import re
15
 
16
 
17
  # Load the model
18
- model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
19
- model.load_checkpoint("model/checkpoint_mf2.pth")
20
- model.to('cuda')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
23
  model_esm.to('cuda')
@@ -40,7 +60,7 @@ choices = {x.lower(): x for x in choices_mf}
40
 
41
 
42
  @spaces.GPU
43
- def generate_caption(protein, prompt):
44
  # Process the image and the prompt
45
  # with open('/home/user/app/example.fasta', 'w') as f:
46
  # f.write('>{}\n'.format("protein_name"))
@@ -123,6 +143,7 @@ def generate_caption(protein, prompt):
123
  'text_input': ['none'],
124
  'prompt': [prompt]}
125
 
 
126
  # Generate the output
127
  prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
128
  repetition_penalty=1.0)
@@ -151,7 +172,6 @@ description = """Quick demonstration of the FAPM model for protein function pred
151
 
152
  The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
153
 
154
-
155
  # iface = gr.Interface(
156
  # fn=generate_caption,
157
  # inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
@@ -161,7 +181,6 @@ The model used in this app is available at [Hugging Face Model Hub](https://hugg
161
  # # Launch the interface
162
  # iface.launch()
163
 
164
-
165
  css = """
166
  #output {
167
  height: 500px;
@@ -175,34 +194,29 @@ with gr.Blocks(css=css) as demo:
175
  with gr.Tab(label="Protein caption"):
176
  with gr.Row():
177
  with gr.Column():
 
178
  input_protein = gr.Textbox(type="text", label="Upload sequence")
179
- # model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
180
- prompt = gr.Textbox(type="text", label="Taxonomy Prompt")
181
  submit_btn = gr.Button(value="Submit")
182
  with gr.Column():
183
  output_text = gr.Textbox(label="Output Text")
184
- # train index 99, 127, 266, 738, 1060 test index 4
185
  gr.Examples(
186
  examples=[
187
- ["MKTLLLTLVVVTIVCLDLGNSLKCYVSREGKTQTCPEGEKLCEKYAVSYFHDGRWRYRYECTSACHRGPYNVCCSTDLCNK", 'Micrurus'],
188
- ["MSSSAGSGHQPSQSRAIPTRTVAISDAAQLPHDYCTTPGGTLFSTTPGGTRIIYDRKFLLDRRNSPMAQTPPCHLPNIPGVTSPGTLIEDSKVEVNNLNNLNNHDRKHAVGDDAQFEMDI", 'Homo'],
189
- ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", 'Sophophora'],
190
- ["MAARGAMLRYLRVNVNPTIQNPRECVLPFSILLRRFSEEVRGSFLDKSEVTDRVLSVVKNFQKVDPSKVTPKANFQNDLGLDSLDSVEVVMALEEEFGFEIPDNEADKIQSIDLAVDFIASHPQAK", 'Arabidopsis'],
191
  ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
192
  ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
193
  ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
194
- ['MQMYKLTAGTTGYHTLLTRTQAEHMLSLWGDKYSIDDCTPSNPIYSPSRYTKLELVYMAANATA', 'Bacteriophage'],
195
- ['MSITAMDAKLQRILEESTCFGIGHDPNVKECKMCDVREQCKAKTQGMNVPTPTRKKPEDVAPAKEKPTTKKTTAKKSTAKEEKKETAPKAKETKAKPKSKPKKAKAPENPNLPNFKEMSFEELVELAKERNVEWKDYNSPNITRMRLIMALKASY', 'Bacteriophage'],
196
  ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
197
  ],
198
- inputs=[input_protein, prompt],
199
  outputs=[output_text],
200
  fn=generate_caption,
201
  cache_examples=True,
202
  label='Try examples'
203
  )
204
-
205
- submit_btn.click(generate_caption, [input_protein, prompt], [output_text])
206
 
207
  demo.launch(debug=True)
208
 
 
7
  from lavis.models.base_model import FAPMConfig
8
  import spaces
9
  import gradio as gr
10
+ # from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
  from data.evaluate_data.utils import Ontology
13
  import difflib
 
15
 
16
 
17
  # Load the model
18
+ # model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
19
+ # model.load_checkpoint("model/checkpoint_mf2.pth")
20
+ # model.to('cuda')
21
+
22
+ def get_model(type='Molecule Function'):
23
+ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
24
+ if type == 'Molecule Function':
25
+ model.load_checkpoint("model/checkpoint_mf2.pth")
26
+ model.to('cuda')
27
+ elif type == 'Biological Process':
28
+ model.load_checkpoint("model/checkpoint_bp1.pth")
29
+ model.to('cuda')
30
+ elif type == 'Cellar Component':
31
+ model.load_checkpoint("model/checkpoint_cc2.pth")
32
+ model.to('cuda')
33
+
34
+
35
+ models = {
36
+ 'Molecule Function': get_model('Molecule Function'),
37
+ 'Biological Process': get_model('Biological Process'),
38
+ 'Cellar Component': get_model('Cellar Component'),
39
+ }
40
+
41
 
42
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
43
  model_esm.to('cuda')
 
60
 
61
 
62
  @spaces.GPU
63
+ def generate_caption(protein, prompt, model_id):
64
  # Process the image and the prompt
65
  # with open('/home/user/app/example.fasta', 'w') as f:
66
  # f.write('>{}\n'.format("protein_name"))
 
143
  'text_input': ['none'],
144
  'prompt': [prompt]}
145
 
146
+ model = models[model_id]
147
  # Generate the output
148
  prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
149
  repetition_penalty=1.0)
 
172
 
173
  The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
174
 
 
175
  # iface = gr.Interface(
176
  # fn=generate_caption,
177
  # inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
 
181
  # # Launch the interface
182
  # iface.launch()
183
 
 
184
  css = """
185
  #output {
186
  height: 500px;
 
194
  with gr.Tab(label="Protein caption"):
195
  with gr.Row():
196
  with gr.Column():
197
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='Molecule Function')
198
  input_protein = gr.Textbox(type="text", label="Upload sequence")
199
+ prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
 
200
  submit_btn = gr.Button(value="Submit")
201
  with gr.Column():
202
  output_text = gr.Textbox(label="Output Text")
203
+ # O14813 train index 127, 266, 738, 1060 test index 4
204
  gr.Examples(
205
  examples=[
206
+ ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
207
+ ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
 
 
208
  ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
209
  ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
210
  ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
 
 
211
  ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
212
  ],
213
+ inputs=[input_protein, prompt, model_selector],
214
  outputs=[output_text],
215
  fn=generate_caption,
216
  cache_examples=True,
217
  label='Try examples'
218
  )
219
+ submit_btn.click(generate_caption, [input_protein, prompt, model_selector], [output_text])
 
220
 
221
  demo.launch(debug=True)
222