AtomBiow commited on
Commit
eef7def
1 Parent(s): 3308cfa

working local gradio server

Browse files
Files changed (2) hide show
  1. api_prediction.py +1 -1
  2. inference.py +23 -19
api_prediction.py CHANGED
@@ -143,7 +143,7 @@ class AptaTransPipeline_Dist(AptaTransPipeline):
143
  print('Predicting the Aptamer-Protein Interaction')
144
  try:
145
  print("loading the best model for api!")
146
- self.model.load_state_dict(torch.load('./models/test_lr=1e-06_batch_size=16_epochs=30.pt', map_location=self.device))
147
  except:
148
  print('there is no best model file.')
149
  print('You need to train the model for predicting API!')
 
143
  print('Predicting the Aptamer-Protein Interaction')
144
  try:
145
  print("loading the best model for api!")
146
+ self.model.load_state_dict(torch.load('./models/model.pt', map_location=self.device))
147
  except:
148
  print('there is no best model file.')
149
  print('You need to train the model for predicting API!')
inference.py CHANGED
@@ -1,13 +1,8 @@
1
  from api_prediction import AptaTransPipeline_Dist
2
  import argparse
 
3
 
4
  def infer(prot, apta):
5
-
6
- print('Input protein: ', prot)
7
- print('Input aptamer: ', apta)
8
-
9
- print('Initializing model.')
10
-
11
  pipeline = AptaTransPipeline_Dist(
12
  lr=1e-6,
13
  weight_decay=None,
@@ -25,20 +20,29 @@ def infer(prot, apta):
25
  load_best_pt=True, # already loads the pretrained model using the datasets included in repo -- no need to run the bottom two cells
26
  device='cuda',
27
  seed=1004)
 
28
  scores = pipeline.inference([apta], [prot], [0])
29
- print('Your predicted score is: ', scores[0])
 
 
 
 
 
 
 
 
 
 
30
 
31
- def main():
32
- parser = argparse.ArgumentParser(
33
- prog='API Pipeline Inference',
34
- description='From a protein and RNA aptamer sequence, predict the binding score.')
35
- parser.add_argument('-p', '--prot')
36
- parser.add_argument('-a', '--apta')
37
 
38
- args = parser.parse_args()
39
- protein = str(args.prot)
40
- apta = str(args.apta)
41
- infer(protein, apta)
 
 
 
 
42
 
43
- if __name__ == "__main__":
44
- main()
 
1
  from api_prediction import AptaTransPipeline_Dist
2
  import argparse
3
+ import gradio as gr
4
 
5
  def infer(prot, apta):
 
 
 
 
 
 
6
  pipeline = AptaTransPipeline_Dist(
7
  lr=1e-6,
8
  weight_decay=None,
 
20
  load_best_pt=True, # already loads the pretrained model using the datasets included in repo -- no need to run the bottom two cells
21
  device='cuda',
22
  seed=1004)
23
+
24
  scores = pipeline.inference([apta], [prot], [0])
25
+ return scores[0]
26
+
27
+ def comparison(protein, aptamers):
28
+ aptamers = aptamers.split('\n')
29
+ pairs = [[protein, aptamer] for aptamer in aptamers]
30
+ print(pairs)
31
+ scores = []
32
+
33
+ for pair in pairs:
34
+ score = infer(pair[0], pair[1])
35
+ scores.append(score)
36
 
37
+ return scores
 
 
 
 
 
38
 
39
+ iface = gr.Interface(
40
+ fn=comparison,
41
+ inputs=[
42
+ gr.Textbox(lines=2, placeholder="Protein"),
43
+ gr.Textbox(lines=10,placeholder="Aptamers (1 per line)")
44
+ ],
45
+ outputs=gr.Textbox()
46
+ )
47
 
48
+ iface.launch()