""" This file tests whether the model predictions for MHNfs match the predictions made on the JKU development server (varified model, server conda env with spec. packages ...) """ #--------------------------------------------------------------------------------------- # Dependencies import pytest import torch import pandas as pd from prediction_pipeline import ActivityPredictor #--------------------------------------------------------------------------------------- # Define tests class TestActivityPredictor: def test_mhnfs_prediction(self, model_input_query, model_input_support_actives, model_input_support_inactives, model_predictions): # Load model predictor = ActivityPredictor() # Define additional inputs to model - i.e. support set sizes support_actives_size = torch.tensor(model_input_support_actives.shape[1]) support_inactives_size = torch.tensor(model_input_support_inactives.shape[1]) # Make predictions predictions = predictor.model( model_input_query, model_input_support_actives, model_input_support_inactives, support_actives_size, support_inactives_size ).detach() # Compare predictions assert torch.allclose(predictions, model_predictions, atol=0.01, rtol=0.) def test_query_mol_return(self): # Support set support_actives_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1"] support_inactives_smiles = ["CCN(CC)C(=S)SSC(=S)N(CC)CCCCC"] # Load activity predictor predictor = ActivityPredictor() # Check 1: Query mols given as a list query_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1", "CCN(CC)C(=S)SSC(=S)N(CC)CC"] _ = predictor.predict(query_smiles, support_actives_smiles, support_inactives_smiles) query_output = predictor._return_query_mols_as_list() assert query_output == query_smiles # Check 2: Query mols given as a string query_smiles_str = ("CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1," "CCN(CC)C(=S)SSC(=S)N(CC)CC") _ = predictor.predict(query_smiles_str, support_actives_smiles, support_inactives_smiles) query_output = predictor._return_query_mols_as_list() assert query_output == query_smiles # Check 3: Query mols given as a pd.Series query_smiles_series = pd.DataFrame({"smiles": ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1", "CCN(CC)C(=S)SSC(=S)N(CC)CC"]}) _ = predictor.predict(query_smiles_series, support_actives_smiles, support_inactives_smiles) query_output = predictor._return_query_mols_as_list() assert query_output == query_smiles # Check 4: Query molecules storage is None predictor.query_molecules = None with pytest.raises(ValueError): predictor._return_query_mols_as_list() # Check 5: Other data types predictor.query_molecules = 123 # any other data type with pytest.raises(TypeError): predictor._return_query_mols_as_list()