|
"""
This file tests whether the MHNfs model predictions match the predictions made on
the JKU development server (verified model, server conda environment with the
specified packages, etc.).
"""
|
|
|
|
|
|
|
import pytest |
|
import torch |
|
import pandas as pd |
|
from prediction_pipeline import ActivityPredictor |
|
|
|
|
|
|
|
|
|
class TestActivityPredictor:
    """Tests for ``ActivityPredictor`` from ``prediction_pipeline``."""

    def test_mhnfs_prediction(self, model_input_query, model_input_support_actives,
                              model_input_support_inactives, model_predictions):
        """Check that the MHNfs model reproduces the reference predictions.

        All arguments are pytest fixtures (presumably defined in a
        ``conftest.py`` not shown here): pre-computed model inputs and the
        reference predictions expected for exactly those inputs.
        """
        predictor = ActivityPredictor()

        # Dimension 1 of each support tensor is passed as the support-set size;
        # presumably the layout is (batch, n_support, n_features) - TODO confirm.
        support_actives_size = torch.tensor(model_input_support_actives.shape[1])
        support_inactives_size = torch.tensor(model_input_support_inactives.shape[1])

        # Pure inference: no autograd graph is needed for the comparison.
        with torch.no_grad():
            predictions = predictor.model(
                model_input_query,
                model_input_support_actives,
                model_input_support_inactives,
                support_actives_size,
                support_inactives_size,
            )

        # Absolute tolerance only (rtol=0): the allowed deviation is a fixed
        # 0.01 regardless of the magnitude of the predicted values.
        assert torch.allclose(predictions, model_predictions, atol=0.01, rtol=0.)

    def test_query_mol_return(self):
        """Check ``_return_query_mols_as_list`` for every query input format.

        The same query molecules passed to ``predict`` as a list of SMILES, as a
        comma-separated SMILES string, or as a DataFrame with a ``smiles``
        column must all be returned as the same list of SMILES strings.
        Query molecules set to ``None`` must raise ``ValueError``, and query
        molecules of an unsupported type (an ``int``) must raise ``TypeError``.
        """
        support_actives_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1"]
        support_inactives_smiles = ["CCN(CC)C(=S)SSC(=S)N(CC)CCCCC"]

        predictor = ActivityPredictor()

        query_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1",
                        "CCN(CC)C(=S)SSC(=S)N(CC)CC"]

        # The same two query molecules in each input format under test,
        # checked in insertion order on the same predictor instance.
        query_inputs = {
            "list": query_smiles,
            "comma-separated str": ",".join(query_smiles),
            "DataFrame": pd.DataFrame({"smiles": query_smiles}),
        }

        for input_format, query_input in query_inputs.items():
            predictor.predict(query_input, support_actives_smiles,
                              support_inactives_smiles)
            query_output = predictor._return_query_mols_as_list()
            assert query_output == query_smiles, (
                f"query molecules given as {input_format} were not returned "
                f"as the expected list of SMILES"
            )

        # Missing query molecules are expected to raise ValueError.
        predictor.query_molecules = None
        with pytest.raises(ValueError):
            predictor._return_query_mols_as_list()

        # Query molecules of an unsupported type are expected to raise TypeError.
        predictor.query_molecules = 123
        with pytest.raises(TypeError):
            predictor._return_query_mols_as_list()
|
|
|
|
|
|
|
|
|
|
|
|
|
|