AlienChen commited on
Commit
c14a243
·
verified ·
1 Parent(s): f13b58d

Update models/peptide_classifiers.py

Browse files
Files changed (1) hide show
  1. models/peptide_classifiers.py +3 -3
models/peptide_classifiers.py CHANGED
@@ -149,7 +149,7 @@ class MotifModel(nn.Module):
149
 
150
  class HemolysisModel:
151
  def __init__(self, device):
152
- self.predictor = xgb.Booster(model_file='../classifier_ckpt/wt_hemolysis.json')
153
 
154
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
155
  self.model.eval()
@@ -206,7 +206,7 @@ class MLPClassifier(nn.Module):
206
 
207
  class NonfoulingModel:
208
  def __init__(self, device):
209
- ckpt = torch.load('../classifier_ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
210
  best_params = ckpt["best_params"]
211
  self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
212
  self.predictor.load_state_dict(ckpt["state_dict"])
@@ -346,7 +346,7 @@ class HalfLifeModel:
346
  def __init__(
347
  self,
348
  device,
349
- ckpt_path = "../classifier_ckpt/wt_halflife.pt",
350
  ):
351
  self.device = device
352
 
 
149
 
150
  class HemolysisModel:
151
  def __init__(self, device):
152
+ self.predictor = xgb.Booster(model_file='./classifier_ckpt/wt_hemolysis.json')
153
 
154
  self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
155
  self.model.eval()
 
206
 
207
  class NonfoulingModel:
208
  def __init__(self, device):
209
+ ckpt = torch.load('./classifier_ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
210
  best_params = ckpt["best_params"]
211
  self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
212
  self.predictor.load_state_dict(ckpt["state_dict"])
 
346
  def __init__(
347
  self,
348
  device,
349
+ ckpt_path = "./classifier_ckpt/wt_halflife.pt",
350
  ):
351
  self.device = device
352