"""
This module provides a simple predict function for the MHNfs model.

It loads the model from the provided checkpoint, creates the necessary helper
inputs, and makes predictions for a list of molecules.
"""

#---------------------------------------------------------------------------------------
# Dependencies
from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
import streamlit as st
import torch

from src.data_preprocessing.create_model_inputs import (create_query_input,
                                                         create_support_set_input)
from src.mhnfs.model import MHNfs

#---------------------------------------------------------------------------------------
# Model loading


@st.cache_resource  # Caching for streamlit: the model is loaded once and shared
def _load_model():
    """Load the MHNfs checkpoint and prepare it for inference.

    The checkpoint is expected at ``assets/mhnfs_data/mhnfs_checkpoint.ckpt``
    inside the directory two levels above this file.

    Returns:
        The ``MHNfs`` model in eval mode, after ``_update_context_set_embedding``
        has been called on it.
    """
    # Seed the Python/NumPy/torch RNGs so model setup is reproducible.
    pl.seed_everything(1234)

    # Resolve relative to this file, not the working directory. parents[1] is the
    # directory two levels up, matching the former ``__file__.rsplit("/", 2)[0]``.
    # It also handles Windows separators and a relative ``__file__``.
    base_dir = Path(__file__).resolve().parents[1]
    checkpoint_path = base_dir / "assets" / "mhnfs_data" / "mhnfs_checkpoint.ckpt"

    model = MHNfs.load_from_checkpoint(str(checkpoint_path))
    # Project helper. Presumably precomputes the context-set embedding used at
    # inference time -- verify in src.mhnfs.model.
    model._update_context_set_embedding()
    model.eval()
    return model


#---------------------------------------------------------------------------------------
# Define predictor class


class ActivityPredictor:
    """Run MHNfs activity predictions for query molecules given a support set.

    Attributes:
        model: The loaded ``MHNfs`` model. It comes from Streamlit's resource
            cache, so all instances share one model object.
        query_molecules: The ``query_smiles`` argument of the most recent
            successful ``predict`` call, or ``None`` if there has been none yet.
    """

    def __init__(self):
        # Load model (cached, so creating more predictors does not reload it)
        self.model = _load_model()
        # Initiate query mol storage
        self.query_molecules = None

    def predict(self, query_smiles, support_activces_smiles, support_inactives_smiles):
        """Predict one score per query molecule against active/inactive supports.

        Note: the misspelled parameter name ``support_activces_smiles`` is kept
        so existing keyword callers keep working.

        Args:
            query_smiles: Query molecules, forwarded to ``create_query_input``.
                Presumably a list of SMILES, a comma-separated SMILES string, or a
                DataFrame with a ``smiles`` column (the forms
                ``_return_query_mols_as_list`` understands) -- confirm against
                ``create_query_input``.
            support_activces_smiles: Active support-set molecules, forwarded to
                ``create_support_set_input``.
            support_inactives_smiles: Inactive support-set molecules, forwarded
                to ``create_support_set_input``.

        Returns:
            A flattened 1-D NumPy array of the model outputs. Presumably one
            activity score per query molecule -- verify against MHNfs.forward.
        """
        # Create model inputs
        query_input = create_query_input(query_smiles)
        support_actives_input, support_actives_size = create_support_set_input(
            support_activces_smiles
        )
        support_inactives_input, support_inactives_size = create_support_set_input(
            support_inactives_smiles
        )

        # Make predictions. This is inference only, so skip building the autograd graph.
        with torch.no_grad():
            predictions = self.model(
                query_input,
                support_actives_input,
                support_inactives_input,
                support_actives_size,
                support_inactives_size,
            )

        # Remember the queries only after their prediction succeeded, so the stored
        # molecules always match the last returned predictions.
        self.query_molecules = query_smiles

        # .cpu() lets the NumPy conversion work even if the model sits on a GPU.
        return predictions.detach().cpu().numpy().flatten()

    def _return_query_mols_as_list(self):
        """Return the stored query molecules as a list of SMILES strings.

        Returns:
            The stored list as-is. For a string, it is split on commas and
            surrounding whitespace is stripped. For a DataFrame, its ``smiles``
            column is converted to a list.

        Raises:
            ValueError: If ``predict`` has not stored any query molecules yet.
            TypeError: If the stored query has an unsupported type.
        """
        queries = self.query_molecules
        if queries is None:
            raise ValueError("No query molecules have been stored yet. "
                             "Run predict-function first.")
        if isinstance(queries, list):
            return queries
        if isinstance(queries, str):
            return [smiles.strip() for smiles in queries.split(",")]
        if isinstance(queries, pd.DataFrame):
            return queries["smiles"].tolist()
        raise TypeError("Type of query molecules not recognized. "
                        "Please check input type.")


#---------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Create predictor
    predictor = ActivityPredictor()

    # Create example inputs
    query_smiles = ["C1CCCCC1", "C1CCCCC1", "C1CCCCC1", "C1CCCCC1"]
    support_actives_smiles = ["C1CCCCC1", "C1CCCCC1"]
    support_inactives_smiles = ["C1CCCCC1", "C1CCCCC1"]

    # Make predictions
    predictions = predictor.predict(query_smiles, support_actives_smiles,
                                    support_inactives_smiles)

    print(predictions)