# PROTAC-Degradation-Predictor / tests/test_pytorch_model.py
# Last commit by ribesstefano: "Fixed some tests and added XGBoost to the API" (1171189)
# File size at that revision: 2.54 kB
import pytest
import os
import sys
import logging
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor
import torch
def test_protac_model():
    """Check PROTAC_Model's explicit argument and its default hyperparameters.

    Only ``hidden_dim`` is passed; every other attribute checked below is
    expected to hold its constructor default, and no datasets are attached.
    """
    model = PROTAC_Model(hidden_dim=128)
    # Explicitly supplied argument.
    assert model.hidden_dim == 128
    # Default embedding sizes.
    assert model.smiles_emb_dim == 256
    assert model.poi_emb_dim == 1024
    assert model.e3_emb_dim == 1024
    assert model.cell_emb_dim == 768
    # Default training hyperparameters.
    assert model.batch_size == 128
    assert model.learning_rate == 0.001
    assert model.dropout == 0.2
    assert model.join_embeddings == 'sum'
    # No datasets are provided to the constructor here.
    assert model.train_dataset is None
    assert model.val_dataset is None
    assert model.test_dataset is None
    assert model.disabled_embeddings == []
    assert model.apply_scaling is True
def test_protac_predictor():
    """Check PROTAC_Predictor's explicit argument and its default settings.

    Only ``hidden_dim`` is passed; the embedding sizes, join strategy and
    disabled-embedding list are expected to hold their constructor defaults.
    """
    predictor = PROTAC_Predictor(hidden_dim=128)
    # Explicitly supplied argument.
    assert predictor.hidden_dim == 128
    # Default embedding sizes.
    assert predictor.smiles_emb_dim == 256
    assert predictor.poi_emb_dim == 1024
    assert predictor.e3_emb_dim == 1024
    assert predictor.cell_emb_dim == 768
    # Default embedding-combination settings.
    assert predictor.join_embeddings == 'sum'
    assert predictor.disabled_embeddings == []
def test_load_model(caplog):
    """Placeholder for a checkpoint-loading test; currently skipped.

    The disabled assertions kept below load the checkpoint
    'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
    via ``PROTAC_Model.load_from_checkpoint`` and compare its stored
    hyperparameters against expected values. Because they are commented
    out, the test is reported as skipped rather than passing while
    checking nothing.

    Args:
        caplog: pytest's log-capture fixture; unused while the body is
            disabled.
    """
    pytest.skip("checkpoint-loading assertions are disabled")
    # Reference for re-enabling (unreachable after pytest.skip above):
    # caplog.set_level(logging.WARNING)
    # model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
    # model = PROTAC_Model.load_from_checkpoint(
    #     model_filename,
    #     map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
    # )
    # assert model.hidden_dim == 768
    # assert model.smiles_emb_dim == 224
    # assert model.poi_emb_dim == 1024
    # assert model.e3_emb_dim == 1024
    # assert model.cell_emb_dim == 768
    # assert model.batch_size == 8
    # assert model.learning_rate == 1.843233973932415e-05
    # assert model.dropout == 0.11257777663560328
    # assert model.join_embeddings == 'concat'
    # assert model.disabled_embeddings == []
    # assert model.apply_scaling == True
    # print(model.scalers)
def test_checkpoint_file():
    """Placeholder for inspecting a raw checkpoint file; currently skipped.

    The disabled code kept below only prints the checkpoint's keys,
    hyperparameters, state-dict names and unpickled scalers; it asserts
    nothing. The test is reported as skipped rather than passing vacuously.

    NOTE(review): ``torch.load`` (without ``weights_only=True``) and
    ``pickle.loads`` can execute arbitrary code from the file being read;
    only re-enable this against trusted checkpoints.
    """
    pytest.skip("checkpoint inspection is disabled (debug prints only)")
    # Reference for re-enabling (unreachable after pytest.skip above):
    # model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
    # checkpoint = torch.load(
    #     model_filename,
    #     map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
    # )
    # print(checkpoint.keys())
    # print(checkpoint["hyper_parameters"])
    # print([k for k, v in checkpoint["state_dict"].items()])
    # import pickle
    # print(pickle.loads(checkpoint['scalers']))
if __name__ == "__main__":
    # Run this file's tests only when executed as a script. An unguarded
    # call would start a nested pytest session every time pytest imports
    # this module during collection. Propagate pytest's exit code.
    sys.exit(pytest.main([__file__]))