Spaces:
Running
on
Zero
Running
on
Zero
Anton Bushuiev
commited on
Commit
·
53691e3
1
Parent(s):
ce43754
Basic ZeroGPU setup
Browse files
app.py
CHANGED
@@ -27,6 +27,7 @@ import logging
|
|
27 |
from pathlib import Path
|
28 |
from functools import partial
|
29 |
|
|
|
30 |
import gradio as gr
|
31 |
import torch
|
32 |
import numpy as np
|
@@ -43,7 +44,8 @@ from ppiref.extraction import PPIExtractor
|
|
43 |
from ppiref.utils.ppi import PPIPath
|
44 |
from ppiref.utils.residue import Residue
|
45 |
from ppiformer.tasks.node import DDGPPIformer
|
46 |
-
from ppiformer.utils.api import download_weights
|
|
|
47 |
from ppiformer.utils.torch import fill_diagonal
|
48 |
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
|
49 |
|
@@ -57,6 +59,11 @@ logging.basicConfig(
|
|
57 |
random.seed(0)
|
58 |
|
59 |
|
|
|
|
|
|
|
|
|
|
|
60 |
def process_inputs(inputs, temp_dir):
|
61 |
pdb_code, pdb_path, partners, muts, muts_path = inputs
|
62 |
|
@@ -479,11 +486,15 @@ with app:
|
|
479 |
# Download weights from Zenodo
|
480 |
download_weights()
|
481 |
|
|
|
|
|
|
|
|
|
482 |
# Load models
|
483 |
models = [
|
484 |
DDGPPIformer.load_from_checkpoint(
|
485 |
PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
|
486 |
-
map_location=
|
487 |
).eval()
|
488 |
for i in range(3)
|
489 |
]
|
|
|
27 |
from pathlib import Path
|
28 |
from functools import partial
|
29 |
|
30 |
+
import spaces
|
31 |
import gradio as gr
|
32 |
import torch
|
33 |
import numpy as np
|
|
|
44 |
from ppiref.utils.ppi import PPIPath
|
45 |
from ppiref.utils.residue import Residue
|
46 |
from ppiformer.tasks.node import DDGPPIformer
|
47 |
+
from ppiformer.utils.api import download_weights
|
48 |
+
from ppiformer.utils.api import predict_ddg as predict_ddg_
|
49 |
from ppiformer.utils.torch import fill_diagonal
|
50 |
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
|
51 |
|
|
|
59 |
random.seed(0)
|
60 |
|
61 |
|
62 |
+
@spaces.GPU
|
63 |
+
def predict_ddg(*args, **kwargs):
|
64 |
+
return predict_ddg_(*args, **kwargs)
|
65 |
+
|
66 |
+
|
67 |
def process_inputs(inputs, temp_dir):
|
68 |
pdb_code, pdb_path, partners, muts, muts_path = inputs
|
69 |
|
|
|
486 |
# Download weights from Zenodo
|
487 |
download_weights()
|
488 |
|
489 |
+
# Determine device
|
490 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
491 |
+
device = torch.device('cuda')
|
492 |
+
|
493 |
# Load models
|
494 |
models = [
|
495 |
DDGPPIformer.load_from_checkpoint(
|
496 |
PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
|
497 |
+
map_location=device
|
498 |
).eval()
|
499 |
for i in range(3)
|
500 |
]
|