mhnfs / src /prediction_pipeline.py
Tschoui's picture
move project from private to public space
cf004a6
"""
This module provides a simple predictor class for the MHNfs model.
It loads the model from the provided checkpoint, creates the necessary helper
inputs, and makes predictions for a list of molecules.
"""
#---------------------------------------------------------------------------------------
# Dependencies
import pandas as pd
import pytorch_lightning as pl
import streamlit as st
from src.data_preprocessing.create_model_inputs import (create_query_input,
create_support_set_input)
from src.mhnfs.model import MHNfs
#---------------------------------------------------------------------------------------
# Define predictor class
class ActivityPredictor:
    """Wrap an MHNfs checkpoint and expose a single ``predict`` call.

    On construction the model is restored from
    ``assets/mhnfs_data/mhnfs_checkpoint.ckpt`` (resolved relative to the
    parent of the directory containing this file) and put into eval mode.
    The query passed to the most recent ``predict`` call is kept in
    ``query_molecules`` so it can be turned back into a plain list later.
    """

    def __init__(self):
        """Load the (streamlit-cached) model and reset the stored query."""

        @st.cache_resource  # Caching for streamlit
        def load_model():
            # Function-scope import so the block brings its own dependency.
            import os

            pl.seed_everything(1234)

            # Project root is the parent of the directory holding this file.
            # abspath + dirname/join (instead of splitting on "/") also works
            # when __file__ is relative or uses a different path separator.
            project_root = os.path.dirname(
                os.path.dirname(os.path.abspath(__file__))
            )
            checkpoint_path = os.path.join(
                project_root, "assets", "mhnfs_data", "mhnfs_checkpoint.ckpt"
            )

            model = MHNfs.load_from_checkpoint(checkpoint_path)
            # NOTE(review): presumably refreshes the context-set embedding
            # held by the model before inference -- confirm in MHNfs.
            model._update_context_set_embedding()
            model.eval()
            return model

        # Load model
        self.model = load_model()

        # Initiate query mol storage
        self.query_molecules = None

    def predict(self, query_smiles, support_activces_smiles, support_inactives_smiles):
        """Predict for the query molecules given an active/inactive support set.

        Args:
            query_smiles: Query molecules. Stored unchanged on the instance
                and passed to ``create_query_input``.
            support_activces_smiles: Active support-set molecules, passed to
                ``create_support_set_input``. (The misspelled name is kept
                so existing keyword callers continue to work.)
            support_inactives_smiles: Inactive support-set molecules, passed
                to ``create_support_set_input``.

        Returns:
            The model output, detached, moved to the CPU, converted to a
            NumPy array and flattened to one dimension.
        """
        # Remember the raw query so _return_query_mols_as_list can use it.
        self.query_molecules = query_smiles

        # Create model inputs
        query_input = create_query_input(query_smiles)
        support_actives_input, support_actives_size = create_support_set_input(
            support_activces_smiles
        )
        support_inactives_input, support_inactives_size = create_support_set_input(
            support_inactives_smiles
        )

        # Make predictions
        predictions = self.model(
            query_input,
            support_actives_input,
            support_inactives_input,
            support_actives_size,
            support_inactives_size,
        )

        # .cpu() is a no-op for CPU tensors and avoids the error that
        # .numpy() raises for tensors living on another device.
        return predictions.detach().cpu().numpy().flatten()

    def _return_query_mols_as_list(self):
        """Return the stored query molecules as a list.

        A list is returned as is, a string is split on commas with
        surrounding whitespace stripped from each entry, and a DataFrame
        contributes its ``smiles`` column.

        Raises:
            ValueError: If no query has been stored yet.
            TypeError: If the stored query has any other type.
        """
        stored_query = self.query_molecules

        if stored_query is None:
            raise ValueError(
                "No query molecules have been stored yet. "
                "Run predict-function first."
            )
        if isinstance(stored_query, list):
            return stored_query
        if isinstance(stored_query, str):
            return [smiles.strip() for smiles in stored_query.split(",")]
        if isinstance(stored_query, pd.DataFrame):
            return stored_query.smiles.tolist()

        raise TypeError(
            "Type of query molecules not recognized. "
            "Please check input type."
        )
#---------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke-test inputs: the same cyclohexane SMILES is reused for every
    # query molecule and for both halves of the support set.
    cyclohexane = "C1CCCCC1"
    demo_queries = [cyclohexane] * 4
    demo_actives = [cyclohexane] * 2
    demo_inactives = [cyclohexane] * 2

    # Build the predictor, run one prediction and show the result.
    activity_predictor = ActivityPredictor()
    print(activity_predictor.predict(demo_queries, demo_actives, demo_inactives))