File size: 3,779 Bytes
cf004a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
This module provides a simple predictor class for the MHNfs model.
It loads the model from the provided checkpoint, creates the necessary helper inputs,
and makes predictions for a list of molecules.
"""

#---------------------------------------------------------------------------------------
# Dependencies
import pandas as pd
import pytorch_lightning as pl
import streamlit as st

from src.data_preprocessing.create_model_inputs import (create_query_input,
                                                    create_support_set_input)
from src.mhnfs.model import MHNfs

#---------------------------------------------------------------------------------------
# Define predictor class

class ActivityPredictor:
    """Wrap the MHNfs checkpoint behind a small prediction interface.

    The model is loaded once on construction (through Streamlit's resource
    cache). The query molecules passed to the most recent ``predict`` call are
    kept in ``query_molecules`` so they can be listed afterwards.
    """

    def __init__(self):
        """Load the model and initialise the query-molecule storage."""

        @st.cache_resource  # Caching for streamlit
        def load_model():
            # Local import keeps this block self-contained.
            from pathlib import Path

            pl.seed_everything(1234)
            # The assets folder sits next to the directory that contains this
            # file, i.e. one level above this file's own folder. Using pathlib
            # avoids assuming "/" as the path separator.
            base_dir = Path(__file__).resolve().parents[1]
            checkpoint_path = (
                base_dir / "assets" / "mhnfs_data" / "mhnfs_checkpoint.ckpt"
            )
            model = MHNfs.load_from_checkpoint(str(checkpoint_path))
            # NOTE(review): presumably precomputes the context-set embedding
            # required at inference time -- confirm in MHNfs.
            model._update_context_set_embedding()
            model.eval()

            return model

        # Load model
        self.model = load_model()

        # Initiate query mol storage
        self.query_molecules = None

    def predict(self, query_smiles, support_activces_smiles, support_inactives_smiles):
        """Predict activities for the query molecules given a support set.

        Args:
            query_smiles: Query molecules; stored unchanged in
                ``self.query_molecules`` and passed to ``create_query_input``.
            support_activces_smiles: Active support-set molecules, passed to
                ``create_support_set_input``. (The misspelled name is kept so
                existing keyword callers continue to work.)
            support_inactives_smiles: Inactive support-set molecules, passed to
                ``create_support_set_input``.

        Returns:
            A flattened (1-D) NumPy array holding the model output.
        """
        # Create model inputs
        # Query input
        self.query_molecules = query_smiles
        query_input = create_query_input(query_smiles)

        # Active support set input
        support_actives_input, support_actives_size = create_support_set_input(
            support_activces_smiles
        )

        # Inactive support set input
        support_inactives_input, support_inactives_size = create_support_set_input(
            support_inactives_smiles
        )

        # Make predictions
        predictions = self.model(
            query_input,
            support_actives_input,
            support_inactives_input,
            support_actives_size,
            support_inactives_size,
        )

        # Move to host memory first: Tensor.numpy() only works on CPU tensors.
        preds_numpy = predictions.detach().cpu().numpy().flatten()

        return preds_numpy

    def _return_query_mols_as_list(self):
        """Return the stored query molecules as a list of SMILES strings.

        Supported storage types are a list (returned as is), a comma-separated
        string (split, with surrounding whitespace stripped from each entry)
        and a pandas DataFrame (its ``smiles`` column is returned as a list).

        Raises:
            ValueError: If no query molecules have been stored yet.
            TypeError: If the stored object has an unsupported type.
        """
        stored = self.query_molecules

        if stored is None:
            raise ValueError("No query molecules have been stored yet. "
                             "Run predict-function first.")
        if isinstance(stored, list):
            return stored
        if isinstance(stored, str):
            return [smiles.strip() for smiles in stored.split(",")]
        if isinstance(stored, pd.DataFrame):
            return stored.smiles.tolist()

        raise TypeError("Type of query molecules not recognized. "
                        "Please check input type.")
            
#---------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: run the predictor on a tiny example in which every query and
    # support-set molecule is cyclohexane.
    cyclohexane = "C1CCCCC1"
    query_smiles = [cyclohexane] * 4
    support_actives_smiles = [cyclohexane] * 2
    support_inactives_smiles = [cyclohexane] * 2

    # Build the predictor (this loads the model checkpoint)
    predictor = ActivityPredictor()

    # Predict and show the raw outputs
    predictions = predictor.predict(
        query_smiles, support_actives_smiles, support_inactives_smiles
    )
    print(predictions)