import pytest | |
import os | |
import sys | |
import logging | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor | |
import torch | |
def test_protac_model(): | |
model = PROTAC_Model(hidden_dim=128) | |
assert model.hidden_dim == 128 | |
assert model.smiles_emb_dim == 256 | |
assert model.poi_emb_dim == 1024 | |
assert model.e3_emb_dim == 1024 | |
assert model.cell_emb_dim == 768 | |
assert model.batch_size == 128 | |
assert model.learning_rate == 0.001 | |
assert model.dropout == 0.2 | |
assert model.join_embeddings == 'sum' | |
assert model.train_dataset is None | |
assert model.val_dataset is None | |
assert model.test_dataset is None | |
assert model.disabled_embeddings == [] | |
assert model.apply_scaling == True | |
def test_protac_predictor(): | |
predictor = PROTAC_Predictor(hidden_dim=128) | |
assert predictor.hidden_dim == 128 | |
assert predictor.smiles_emb_dim == 256 | |
assert predictor.poi_emb_dim == 1024 | |
assert predictor.e3_emb_dim == 1024 | |
assert predictor.cell_emb_dim == 768 | |
assert predictor.join_embeddings == 'sum' | |
assert predictor.disabled_embeddings == [] | |
def test_load_model(caplog): | |
# caplog.set_level(logging.WARNING) | |
# model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt' | |
# model = PROTAC_Model.load_from_checkpoint( | |
# model_filename, | |
# map_location=torch.device("cpu") if not torch.cuda.is_available() else None, | |
# ) | |
# assert model.hidden_dim == 768 | |
# assert model.smiles_emb_dim == 224 | |
# assert model.poi_emb_dim == 1024 | |
# assert model.e3_emb_dim == 1024 | |
# assert model.cell_emb_dim == 768 | |
# assert model.batch_size == 8 | |
# assert model.learning_rate == 1.843233973932415e-05 | |
# assert model.dropout == 0.11257777663560328 | |
# assert model.join_embeddings == 'concat' | |
# assert model.disabled_embeddings == [] | |
# assert model.apply_scaling == True | |
# print(model.scalers) | |
pass | |
def test_checkpoint_file(): | |
# model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt' | |
# checkpoint = torch.load( | |
# model_filename, | |
# map_location=torch.device("cpu") if not torch.cuda.is_available() else None, | |
# ) | |
# print(checkpoint.keys()) | |
# print(checkpoint["hyper_parameters"]) | |
# print([k for k, v in checkpoint["state_dict"].items()]) | |
# import pickle | |
# print(pickle.loads(checkpoint['scalers'])) | |
pass | |
pytest.main() | |