wenkai commited on
Commit
eb615db
1 Parent(s): aad9fe1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -40
app.py CHANGED
@@ -32,12 +32,12 @@ def get_model(type='Molecule Function'):
32
  models = {
33
  'Molecule Function': get_model('Molecule Function'),
34
  'Biological Process': get_model('Biological Process'),
35
- 'Cellar Component': get_model('Cellar Component'),
36
  }
37
 
38
 
39
  # Load the mistral model
40
- mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
41
 
42
  # Load ESM2 model
43
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
@@ -54,14 +54,23 @@ go_des['text'] = go_des['text'].apply(lambda x: x.lower())
54
  GO_dict = dict(zip(go_des['text'], go_des['id']))
55
  Func_dict = dict(zip(go_des['id'], go_des['text']))
56
 
57
- # terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
58
  terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
59
  choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
60
- choices = {x.lower(): x for x in choices_mf}
61
-
 
 
 
 
 
 
 
 
 
 
62
 
63
  @spaces.GPU
64
- def generate_caption(model_id, protein, prompt):
65
  # Process the image and the prompt
66
  # with open('/home/user/app/example.fasta', 'w') as f:
67
  # f.write('>{}\n'.format("protein_name"))
@@ -144,36 +153,40 @@ def generate_caption(model_id, protein, prompt):
144
  'text_input': ['none'],
145
  'prompt': [prompt]}
146
 
147
- model = models[model_id]
148
- # Generate the output
149
- prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
150
- repetition_penalty=1.0)
151
-
152
- x = prediction[0]
153
- x = [eval(i) for i in x.split('; ')]
154
- pred_terms = []
155
- temp = []
156
- for i in x:
157
- txt = i[0]
158
- prob = i[1]
159
- sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
160
- if len(sim_list) > 0:
161
- t_standard = sim_list[0]
162
- if t_standard not in temp:
163
- pred_terms.append(t_standard+f'({prob})')
164
- temp.append(t_standard)
 
 
 
165
  if prompt == 'none':
166
  res_str = "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
167
  else:
168
  res_str = "No available predictions for this protein, you can use other two types of model or try another sequence!"
169
- if len(pred_terms) == 0:
170
  return res_str
171
- if model_id == 'Molecule Function':
172
- res_str = f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}"
173
- elif model_id == 'Biological Process':
174
- res_str = f"Based on the given amino acid sequence, it is likely involved in the {', '.join(pred_terms)}"
175
- elif model_id == 'Cellar Component':
176
- res_str = f"Based on the given amino acid sequence, it's subcellular localization is within the {', '.join(pred_terms)}"
 
177
  return res_str
178
  # return "test"
179
 
@@ -205,7 +218,6 @@ with gr.Blocks(css=css) as demo:
205
  with gr.Tab(label="Protein caption"):
206
  with gr.Row():
207
  with gr.Column():
208
- model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='Molecule Function')
209
  input_protein = gr.Textbox(type="text", label="Upload sequence")
210
  prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
211
  submit_btn = gr.Button(value="Submit")
@@ -214,20 +226,20 @@ with gr.Blocks(css=css) as demo:
214
  # O14813 train index 127, 266, 738, 1060 test index 4
215
  gr.Examples(
216
  examples=[
217
- ["Molecule Function", "MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
218
- ["Molecule Function", "MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
219
- ["Molecule Function", "MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
220
- ["Molecule Function", 'MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
221
- ["Molecule Function", 'MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
222
- ["Molecule Function", 'MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
223
  ],
224
- inputs=[model_selector, input_protein, prompt],
225
  outputs=[output_text],
226
  fn=generate_caption,
227
  cache_examples=True,
228
  label='Try examples'
229
  )
230
- submit_btn.click(generate_caption, [model_selector, input_protein, prompt], [output_text])
231
 
232
  demo.launch(debug=True)
233
 
 
32
  models = {
33
  'Molecule Function': get_model('Molecule Function'),
34
  'Biological Process': get_model('Biological Process'),
35
+ 'Cellular Component': get_model('Cellar Component'),
36
  }
37
 
38
 
39
  # Load the mistral model
40
+ mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16).to('cuda')
41
 
42
  # Load ESM2 model
43
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
 
54
  GO_dict = dict(zip(go_des['text'], go_des['id']))
55
  Func_dict = dict(zip(go_des['id'], go_des['text']))
56
 
 
57
  terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
58
  choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
59
+ choices_mf = {x.lower(): x for x in choices_mf}
60
+ terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
61
+ choices_bp = [Func_dict[i] for i in list(set(terms_bp['gos']))]
62
+ choices_bp = {x.lower(): x for x in choices_bp}
63
+ terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
64
+ choices_cc = [Func_dict[i] for i in list(set(terms_cc['gos']))]
65
+ choices_cc = {x.lower(): x for x in choices_cc}
66
+ choices = {
67
+ 'Molecule Function': choices_mf,
68
+ 'Biological Process': choices_bp,
69
+ 'Cellular Component': choices_cc,
70
+ }
71
 
72
  @spaces.GPU
73
+ def generate_caption(protein, prompt):
74
  # Process the image and the prompt
75
  # with open('/home/user/app/example.fasta', 'w') as f:
76
  # f.write('>{}\n'.format("protein_name"))
 
153
  'text_input': ['none'],
154
  'prompt': [prompt]}
155
 
156
+ union_pred_terms = []
157
+ for model_id in models.keys():
158
+ model = models[model_id]
159
+ # Generate the output
160
+ prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
161
+ repetition_penalty=1.0)
162
+ x = prediction[0]
163
+ x = [eval(i) for i in x.split('; ')]
164
+ pred_terms = []
165
+ temp = []
166
+ for i in x:
167
+ txt = i[0]
168
+ prob = i[1]
169
+ sim_list = difflib.get_close_matches(txt.lower(), choices[model_id], n=1, cutoff=0.9)
170
+ if len(sim_list) > 0:
171
+ t_standard = sim_list[0]
172
+ if t_standard not in temp:
173
+ pred_terms.append(t_standard+f'({prob})')
174
+ temp.append(t_standard)
175
+ union_pred_terms.append(pred_terms)
176
+
177
  if prompt == 'none':
178
  res_str = "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
179
  else:
180
  res_str = "No available predictions for this protein, you can use other two types of model or try another sequence!"
181
+ if len(union_pred_terms[0]) == 0 and len(union_pred_terms[1]) == 0 and len(union_pred_terms[2]) == 0:
182
  return res_str
183
+ res_str = ''
184
+ if len(union_pred_terms[0]) != 0:
185
+ res_str += f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}. "
186
+ if len(union_pred_terms[1]) != 0:
187
+ res_str += f"It is likely involved in the {', '.join(pred_terms)}. "
188
+ if len(union_pred_terms[2]) != 0:
189
+ res_str += f"It's subcellular localization is within the {', '.join(pred_terms)}."
190
  return res_str
191
  # return "test"
192
 
 
218
  with gr.Tab(label="Protein caption"):
219
  with gr.Row():
220
  with gr.Column():
 
221
  input_protein = gr.Textbox(type="text", label="Upload sequence")
222
  prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
223
  submit_btn = gr.Button(value="Submit")
 
226
  # O14813 train index 127, 266, 738, 1060 test index 4
227
  gr.Examples(
228
  examples=[
229
+ ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
230
+ ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
231
+ ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
232
+ ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
233
+ ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
234
+ ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
235
  ],
236
+ inputs=[input_protein, prompt],
237
  outputs=[output_text],
238
  fn=generate_caption,
239
  cache_examples=True,
240
  label='Try examples'
241
  )
242
+ submit_btn.click(generate_caption, [input_protein, prompt], [output_text])
243
 
244
  demo.launch(debug=True)
245