File size: 3,416 Bytes
cf004a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
"""
This file tests whether the model predictions for MHNfs match the predictions made on
the JKU development server (verified model, server conda env with specific packages, ...)
"""
#---------------------------------------------------------------------------------------
# Dependencies
import pytest
import torch
import pandas as pd
from prediction_pipeline import ActivityPredictor
#---------------------------------------------------------------------------------------
# Define tests
class TestActivityPredictor:
    """Regression tests for ``ActivityPredictor``.

    Covers two things: the raw model output for fixed inputs compared against
    stored reference predictions, and the normalisation of query molecules
    into a plain list of SMILES strings.
    """

    def test_mhnfs_prediction(self, model_input_query, model_input_support_actives,
                              model_input_support_inactives, model_predictions):
        """Check that the model reproduces the stored reference predictions.

        All four arguments are pytest fixtures (presumably defined in the
        project's conftest -- not visible from this file): the three model
        input tensors and the expected output tensor.
        """
        activity_predictor = ActivityPredictor()

        # The model takes the support-set sizes as explicit inputs; they are
        # read from dimension 1 of the respective support tensors.
        n_actives = torch.tensor(model_input_support_actives.shape[1])
        n_inactives = torch.tensor(model_input_support_inactives.shape[1])

        model_args = (
            model_input_query,
            model_input_support_actives,
            model_input_support_inactives,
            n_actives,
            n_inactives,
        )
        raw_output = activity_predictor.model(*model_args)
        predictions = raw_output.detach()

        # Absolute tolerance only (rtol disabled).
        assert torch.allclose(predictions, model_predictions, atol=0.01, rtol=0.)

    def test_query_mol_return(self):
        """Check ``_return_query_mols_as_list`` for every supported input form.

        A list, a comma-separated string and a DataFrame with a ``"smiles"``
        column must all come back as the same list of SMILES strings. Stored
        query molecules of ``None`` must raise ``ValueError`` and any other
        data type must raise ``TypeError``.
        """
        # Support set (one active, one inactive molecule)
        actives = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1"]
        inactives = ["CCN(CC)C(=S)SSC(=S)N(CC)CCCCC"]

        activity_predictor = ActivityPredictor()

        # Expected result, which doubles as the list-typed input (check 1)
        query_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1",
                        "CCN(CC)C(=S)SSC(=S)N(CC)CC"]
        # Check 2 input: the same molecules as one comma-separated string
        smiles_string = ("CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1,"
                         "CCN(CC)C(=S)SSC(=S)N(CC)CC")
        # Check 3 input: the same molecules as a DataFrame with a "smiles"
        # column (note: a DataFrame, not a pd.Series)
        smiles_frame = pd.DataFrame({"smiles":
            ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1", "CCN(CC)C(=S)SSC(=S)N(CC)CC"]})

        # Checks 1-3: every input form must yield the same list
        for query_input in (query_smiles, smiles_string, smiles_frame):
            _ = activity_predictor.predict(query_input, actives, inactives)
            returned = activity_predictor._return_query_mols_as_list()
            assert returned == query_smiles

        # Check 4: empty storage (None) -> ValueError
        # Check 5: any other data type (here an int) -> TypeError
        invalid_states = ((None, ValueError), (123, TypeError))
        for stored_value, expected_error in invalid_states:
            activity_predictor.query_molecules = stored_value
            with pytest.raises(expected_error):
                activity_predictor._return_query_mols_as_list()
|