import pytest
import os
import sys
import logging

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor

import torch


def test_protac_model():
    model = PROTAC_Model(hidden_dim=128)
    assert model.hidden_dim == 128
    assert model.smiles_emb_dim == 256
    assert model.poi_emb_dim == 1024
    assert model.e3_emb_dim == 1024
    assert model.cell_emb_dim == 768
    assert model.batch_size == 128
    assert model.learning_rate == 0.001
    assert model.dropout == 0.2
    assert model.join_embeddings == 'sum'
    assert model.train_dataset is None
    assert model.val_dataset is None
    assert model.test_dataset is None
    assert model.disabled_embeddings == []
    assert model.apply_scaling == True

def test_protac_predictor():
    predictor = PROTAC_Predictor(hidden_dim=128)
    assert predictor.hidden_dim == 128
    assert predictor.smiles_emb_dim == 256
    assert predictor.poi_emb_dim == 1024
    assert predictor.e3_emb_dim == 1024
    assert predictor.cell_emb_dim == 768
    assert predictor.join_embeddings == 'sum'
    assert predictor.disabled_embeddings == []
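
# NOTE: The smoke test below is an added sketch, not part of the original
# suite. It assumes PROTAC_Predictor is a torch.nn.Module whose forward pass
# accepts per-branch embedding tensors named poi_emb, e3_emb, cell_emb, and
# smiles_emb (these argument names are assumptions) and returns one prediction
# per sample. It is skipped by default; enable it only after confirming the
# signature against the package source.
@pytest.mark.skip(reason='Hedged sketch: forward-pass signature not confirmed.')
def test_protac_predictor_forward_smoke():
    predictor = PROTAC_Predictor(hidden_dim=128)
    batch_size = 4
    # Dummy inputs sized to match the default embedding dimensions asserted above.
    output = predictor(
        poi_emb=torch.randn(batch_size, 1024),
        e3_emb=torch.randn(batch_size, 1024),
        cell_emb=torch.randn(batch_size, 768),
        smiles_emb=torch.randn(batch_size, 256),
    )
    assert output.shape[0] == batch_size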

def test_load_model(caplog):
    caplog.set_level(logging.WARNING)

    model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
    # Skip rather than fail when the checkpoint is not shipped with the repository.
    if not os.path.exists(model_filename):
        pytest.skip(f'Checkpoint file not found: {model_filename}')

    model = PROTAC_Model.load_from_checkpoint(
        model_filename,
        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
    )
    assert model.hidden_dim == 768
    assert model.smiles_emb_dim == 224
    assert model.poi_emb_dim == 1024
    assert model.e3_emb_dim == 1024
    assert model.cell_emb_dim == 768
    assert model.batch_size == 8
    assert model.learning_rate == 1.843233973932415e-05
    assert model.dropout == 0.11257777663560328
    assert model.join_embeddings == 'concat'
    assert model.disabled_embeddings == []
    assert model.apply_scaling == True
    logging.info(model.scalers)


def test_checkpoint_file():
    model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
    # Skip rather than fail when the checkpoint is not shipped with the repository.
    if not os.path.exists(model_filename):
        pytest.skip(f'Checkpoint file not found: {model_filename}')

    checkpoint = torch.load(
        model_filename,
        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
    )
    # Sanity-check the Lightning checkpoint structure instead of printing it.
    assert 'hyper_parameters' in checkpoint
    assert 'state_dict' in checkpoint
    logging.info(checkpoint['hyper_parameters'])
    logging.info(list(checkpoint['state_dict'].keys()))
    # import pickle
    # print(pickle.loads(checkpoint['scalers']))

if __name__ == '__main__':
    # Guard the direct invocation so that importing this module (e.g. during
    # pytest collection) does not trigger a recursive test run.
    pytest.main()