mschuh commited on
Commit
005026b
·
verified ·
1 Parent(s): c07eff0

Update for ZeroGPU

Browse files
Files changed (1) hide show
  1. model/model.py +4 -2
model/model.py CHANGED
@@ -27,7 +27,6 @@ from model.barlow_twins import BarlowTwins
27
  from utils.sequence import uniprot2sequence, encode_sequences
28
 
29
 
30
- @spaces.GPU
31
  class DTIModel:
32
  def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
33
  self.bt_model = BarlowTwins()
@@ -68,7 +67,8 @@ class DTIModel:
68
  def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False):
69
  morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
70
  return np.array(morgan)
71
-
 
72
  def _encode_sequence(self, sequence: str):
73
  # Clear torch cache
74
  torch.cuda.empty_cache()
@@ -88,10 +88,12 @@ class DTIModel:
88
  print(e)
89
  return None
90
 
 
91
  def _encode_sequence_mult(self, sequences: List[str]):
92
  seq = [self._encode_sequence(sequence) for sequence in sequences]
93
  return np.array(seq)
94
 
 
95
  def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool):
96
  if drug_emb.shape[0] < target_emb.shape[0]:
97
  drug_emb = np.tile(drug_emb, (len(target_emb), 1))
 
27
  from utils.sequence import uniprot2sequence, encode_sequences
28
 
29
 
 
30
  class DTIModel:
31
  def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
32
  self.bt_model = BarlowTwins()
 
67
  def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False):
68
  morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
69
  return np.array(morgan)
70
+
71
+ @spaces.GPU
72
  def _encode_sequence(self, sequence: str):
73
  # Clear torch cache
74
  torch.cuda.empty_cache()
 
88
  print(e)
89
  return None
90
 
91
+ @spaces.GPU
92
  def _encode_sequence_mult(self, sequences: List[str]):
93
  seq = [self._encode_sequence(sequence) for sequence in sequences]
94
  return np.array(seq)
95
 
96
+ @spaces.GPU
97
  def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool):
98
  if drug_emb.shape[0] < target_emb.shape[0]:
99
  drug_emb = np.tile(drug_emb, (len(target_emb), 1))