Anton Bushuiev commited on
Commit
53691e3
·
1 Parent(s): ce43754

Basic ZeroGPU setup

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -27,6 +27,7 @@ import logging
27
  from pathlib import Path
28
  from functools import partial
29
 
 
30
  import gradio as gr
31
  import torch
32
  import numpy as np
@@ -43,7 +44,8 @@ from ppiref.extraction import PPIExtractor
43
  from ppiref.utils.ppi import PPIPath
44
  from ppiref.utils.residue import Residue
45
  from ppiformer.tasks.node import DDGPPIformer
46
- from ppiformer.utils.api import download_weights, predict_ddg
 
47
  from ppiformer.utils.torch import fill_diagonal
48
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
49
 
@@ -57,6 +59,11 @@ logging.basicConfig(
57
  random.seed(0)
58
 
59
 
 
 
 
 
 
60
  def process_inputs(inputs, temp_dir):
61
  pdb_code, pdb_path, partners, muts, muts_path = inputs
62
 
@@ -479,11 +486,15 @@ with app:
479
  # Download weights from Zenodo
480
  download_weights()
481
 
 
 
 
 
482
  # Load models
483
  models = [
484
  DDGPPIformer.load_from_checkpoint(
485
  PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
486
- map_location=torch.device('cpu')
487
  ).eval()
488
  for i in range(3)
489
  ]
 
27
  from pathlib import Path
28
  from functools import partial
29
 
30
+ import spaces
31
  import gradio as gr
32
  import torch
33
  import numpy as np
 
44
  from ppiref.utils.ppi import PPIPath
45
  from ppiref.utils.residue import Residue
46
  from ppiformer.tasks.node import DDGPPIformer
47
+ from ppiformer.utils.api import download_weights
48
+ from ppiformer.utils.api import predict_ddg as predict_ddg_
49
  from ppiformer.utils.torch import fill_diagonal
50
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
51
 
 
59
  random.seed(0)
60
 
61
 
62
+ @spaces.GPU
63
+ def predict_ddg(*args, **kwargs):
64
+ return predict_ddg_(*args, **kwargs)
65
+
66
+
67
  def process_inputs(inputs, temp_dir):
68
  pdb_code, pdb_path, partners, muts, muts_path = inputs
69
 
 
486
  # Download weights from Zenodo
487
  download_weights()
488
 
489
+ # Determine device
490
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
491
+ device = torch.device('cuda')
492
+
493
  # Load models
494
  models = [
495
  DDGPPIformer.load_from_checkpoint(
496
  PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
497
+ map_location=device
498
  ).eval()
499
  for i in range(3)
500
  ]