File size: 3,416 Bytes
cf004a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
This file tests whether the model predictions for MHNfs match the predictions made on
the JKU development server (verified model, server conda env with specific packages, ...).
"""

#---------------------------------------------------------------------------------------
# Dependencies
import pytest
import torch
import pandas as pd
from prediction_pipeline import ActivityPredictor

#---------------------------------------------------------------------------------------
# Define tests

class TestActivityPredictor:
    """Tests for ``ActivityPredictor``.

    Covers two things: the raw model output compared against stored reference
    predictions, and the recovery of the query molecules as a list of SMILES
    strings for the different input formats accepted by ``predict``.
    """

    def test_mhnfs_prediction(self, model_input_query, model_input_support_actives,
                              model_input_support_inactives, model_predictions):
        """Check the model output against the reference predictions.

        All four arguments are pytest fixtures that are not defined in this
        file (presumably provided by a conftest -- verify).
        ``model_predictions`` is expected to hold the reference values the
        model output is compared with.
        """
        activity_predictor = ActivityPredictor()

        # The model additionally takes the support set sizes, which are read
        # from dimension 1 of the support set inputs.
        n_support_actives = torch.tensor(model_input_support_actives.shape[1])
        n_support_inactives = torch.tensor(model_input_support_inactives.shape[1])

        model_output = activity_predictor.model(
            model_input_query,
            model_input_support_actives,
            model_input_support_inactives,
            n_support_actives,
            n_support_inactives,
        )
        predictions = model_output.detach()

        # Purely absolute tolerance: rtol is switched off on purpose.
        assert torch.allclose(predictions, model_predictions, atol=0.01, rtol=0.0)

    def test_query_mol_return(self):
        """Check ``_return_query_mols_as_list`` for all supported input types.

        The query molecules are handed to ``predict`` as a list, as a
        comma-separated string and as a pandas DataFrame with a ``smiles``
        column; each time the same list of SMILES must come back. Afterwards
        the stored query molecules are overwritten directly to trigger the
        error paths.
        """
        first_smiles = "CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1"
        second_smiles = "CCN(CC)C(=S)SSC(=S)N(CC)CC"

        # Support set
        actives = [first_smiles]
        inactives = ["CCN(CC)C(=S)SSC(=S)N(CC)CCCCC"]

        predictor = ActivityPredictor()

        expected_smiles = [first_smiles, second_smiles]
        query_variants = (
            # Check 1: query mols given as a list
            expected_smiles,
            # Check 2: query mols given as a comma-separated string
            ",".join((first_smiles, second_smiles)),
            # Check 3: query mols given as a pd.DataFrame with a "smiles" column
            pd.DataFrame({"smiles": [first_smiles, second_smiles]}),
        )
        for query in query_variants:
            predictor.predict(query, actives, inactives)
            returned_smiles = predictor._return_query_mols_as_list()
            assert returned_smiles == expected_smiles

        invalid_storage_cases = (
            # Check 4: query molecules storage is None
            (None, ValueError),
            # Check 5: any other, unsupported data type
            (123, TypeError),
        )
        for stored_value, expected_error in invalid_storage_cases:
            predictor.query_molecules = stored_value
            with pytest.raises(expected_error):
                predictor._return_query_mols_as_list()