# Source: mhnfs/src/tests/test_prediction_pipeline_model_preds.py
# Author: Tschoui, commit cf004a6 ("move project from private to public space"), 3.42 kB
"""
This file tests whether the model predictions for MHNfs match the predictions made on
the JKU development server (verified model, server conda env with specific packages, ...)
"""
#---------------------------------------------------------------------------------------
# Dependencies
import pytest
import torch
import pandas as pd
from prediction_pipeline import ActivityPredictor
#---------------------------------------------------------------------------------------
# Define tests
class TestActivityPredictor:
    """Regression tests for ``ActivityPredictor`` from ``prediction_pipeline``."""

    def test_mhnfs_prediction(self, model_input_query, model_input_support_actives,
                              model_input_support_inactives, model_predictions):
        """Check raw MHNfs outputs against stored reference predictions.

        All four arguments are pytest fixtures defined outside this file
        (presumably in a ``conftest.py``); ``model_predictions`` holds the
        reference outputs produced on the JKU development server.
        """
        # Load model
        predictor = ActivityPredictor()

        # Define additional inputs to model - i.e. support set sizes.
        # NOTE(review): assumes dim 1 of the support tensors indexes the support
        # molecules — confirm against the model's expected input layout.
        support_actives_size = torch.tensor(model_input_support_actives.shape[1])
        support_inactives_size = torch.tensor(model_input_support_inactives.shape[1])

        # Make predictions. no_grad() skips building an autograd graph for this
        # pure inference call; detach() keeps the result graph-free regardless.
        with torch.no_grad():
            predictions = predictor.model(
                model_input_query,
                model_input_support_actives,
                model_input_support_inactives,
                support_actives_size,
                support_inactives_size,
            ).detach()

        # Compare predictions element by element. torch.allclose broadcasts its
        # inputs, so shapes like (N,) vs (N, 1) would silently be compared as an
        # N x N grid; checking the element count and flattening rules that out.
        assert predictions.numel() == model_predictions.numel()
        assert torch.allclose(predictions.flatten(), model_predictions.flatten(),
                              atol=0.01, rtol=0.)

    def test_query_mol_return(self):
        """Check that ``_return_query_mols_as_list`` normalizes stored queries.

        Query molecules passed to ``predict`` as a list, a comma-separated
        string, or a DataFrame with a ``smiles`` column must all be returned as
        the same list of SMILES strings. A ``None`` storage value must raise
        ``ValueError`` and any other unsupported type must raise ``TypeError``.
        """
        # Support set
        support_actives_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1"]
        support_inactives_smiles = ["CCN(CC)C(=S)SSC(=S)N(CC)CCCCC"]

        # Load activity predictor
        predictor = ActivityPredictor()

        # Expected output for every supported query input format below
        query_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1",
                        "CCN(CC)C(=S)SSC(=S)N(CC)CC"]

        # Check 1: Query mols given as a list
        predictor.predict(query_smiles, support_actives_smiles,
                          support_inactives_smiles)
        assert predictor._return_query_mols_as_list() == query_smiles

        # Check 2: Query mols given as a single comma-separated string
        query_smiles_str = ",".join(query_smiles)
        predictor.predict(query_smiles_str, support_actives_smiles,
                          support_inactives_smiles)
        assert predictor._return_query_mols_as_list() == query_smiles

        # Check 3: Query mols given as a pd.DataFrame with a "smiles" column
        query_smiles_df = pd.DataFrame({"smiles": query_smiles})
        predictor.predict(query_smiles_df, support_actives_smiles,
                          support_inactives_smiles)
        assert predictor._return_query_mols_as_list() == query_smiles

        # Check 4: Query molecules storage is None
        predictor.query_molecules = None
        with pytest.raises(ValueError):
            predictor._return_query_mols_as_list()

        # Check 5: Other data types
        predictor.query_molecules = 123  # any other data type
        with pytest.raises(TypeError):
            predictor._return_query_mols_as_list()