Update models/peptide_classifiers.py
Browse files
models/peptide_classifiers.py
CHANGED
|
@@ -149,7 +149,7 @@ class MotifModel(nn.Module):
|
|
| 149 |
|
| 150 |
class HemolysisModel:
|
| 151 |
def __init__(self, device):
|
| 152 |
-
self.predictor = xgb.Booster(model_file='.
|
| 153 |
|
| 154 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 155 |
self.model.eval()
|
|
@@ -206,7 +206,7 @@ class MLPClassifier(nn.Module):
|
|
| 206 |
|
| 207 |
class NonfoulingModel:
|
| 208 |
def __init__(self, device):
|
| 209 |
-
ckpt = torch.load('.
|
| 210 |
best_params = ckpt["best_params"]
|
| 211 |
self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
|
| 212 |
self.predictor.load_state_dict(ckpt["state_dict"])
|
|
@@ -346,7 +346,7 @@ class HalfLifeModel:
|
|
| 346 |
def __init__(
|
| 347 |
self,
|
| 348 |
device,
|
| 349 |
-
ckpt_path = ".
|
| 350 |
):
|
| 351 |
self.device = device
|
| 352 |
|
|
|
|
| 149 |
|
| 150 |
class HemolysisModel:
|
| 151 |
def __init__(self, device):
|
| 152 |
+
self.predictor = xgb.Booster(model_file='./classifier_ckpt/wt_hemolysis.json')
|
| 153 |
|
| 154 |
self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
|
| 155 |
self.model.eval()
|
|
|
|
| 206 |
|
| 207 |
class NonfoulingModel:
|
| 208 |
def __init__(self, device):
|
| 209 |
+
ckpt = torch.load('./classifier_ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
|
| 210 |
best_params = ckpt["best_params"]
|
| 211 |
self.predictor = MLPClassifier(in_dim=1280, hidden=int(best_params["hidden"]), dropout=float(best_params.get("dropout", 0.1)))
|
| 212 |
self.predictor.load_state_dict(ckpt["state_dict"])
|
|
|
|
| 346 |
def __init__(
|
| 347 |
self,
|
| 348 |
device,
|
| 349 |
+
ckpt_path = "./classifier_ckpt/wt_halflife.pt",
|
| 350 |
):
|
| 351 |
self.device = device
|
| 352 |
|