wenkai committed on
Commit
f3ed046
1 Parent(s): 3daa625

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -1
app.py CHANGED
@@ -9,6 +9,9 @@ import spaces
9
  import gradio as gr
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
 
 
 
12
 
13
 
14
  # Load the model
@@ -20,6 +23,21 @@ model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
20
  model_esm.to('cuda')
21
  model_esm.eval()
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  @spaces.GPU
25
  def generate_caption(protein, prompt):
@@ -106,7 +124,16 @@ def generate_caption(protein, prompt):
106
  prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
107
  repetition_penalty=1.0)
108
 
109
- return prediction
 
 
 
 
 
 
 
 
 
110
  # return "test"
111
 
112
 
 
9
  import gradio as gr
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
+ from data.evaluate_data.utils import Ontology
13
+ import difflib
14
+ import re
15
 
16
 
17
  # Load the model
 
23
  model_esm.to('cuda')
24
  model_esm.eval()
25
 
26
+ godb = Ontology(f'data/go1.4-basic.obo', with_rels=True)
27
+ go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
28
+ go_des.columns = ['id', 'text']
29
+ go_des = go_des.dropna()
30
+ go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
31
+ go_obo_set = set(go_des['id'].tolist())
32
+ go_des['text'] = go_des['text'].apply(lambda x: x.lower())
33
+ GO_dict = dict(zip(go_des['text'], go_des['id']))
34
+ Func_dict = dict(zip(go_des['id'], go_des['text']))
35
+
36
+ # terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
37
+ terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
38
+ choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
39
+ choices = {x.lower(): x for x in choices_mf}
40
+
41
 
42
  @spaces.GPU
43
  def generate_caption(protein, prompt):
 
124
  prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
125
  repetition_penalty=1.0)
126
 
127
+ x = prediction[0]
128
+ x = [eval(i) for i in x.split('; ')]
129
+ pred_terms = []
130
+ for i in x:
131
+ txt = i[0]
132
+ prob = i[1]
133
+ sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
134
+ if len(sim_list) > 0:
135
+ pred_terms.append((sim_list[0], prob))
136
+ return str(pred_terms)
137
  # return "test"
138
 
139