"""
This module defines the input functions for query and support-set molecules.
The input is expected to be a single SMILES string, a list of SMILES strings,
or a pandas DataFrame.
"""
#--------------------------------------------------------------------------------------- | |
# Dependencies | |
from typing import List, Tuple, Union

import pandas as pd
import torch

from src.data_preprocessing.create_descriptors import preprocess_molecules
#--------------------------------------------------------------------------------------- | |
# Define main functions | |
def create_query_input(
    smiles_input: Union[str, List[str], pd.DataFrame],
) -> torch.Tensor:
    """
    Create the input tensor for the query molecules.

    Args:
        smiles_input: A SMILES string, a list of SMILES strings, or a pandas
            DataFrame. It is passed unchanged to ``preprocess_molecules``.

    Returns:
        A float32 tensor of shape ``(n_rows, 1, n_features)``. The two
        outer dimensions come from the 2-D array returned by
        ``preprocess_molecules`` (presumably one row per molecule — confirm
        against that helper). A singleton axis is inserted at position 1.

    Raises:
        ValueError: If ``preprocess_molecules`` does not return a 2-D array.
    """
    # Create vector representation
    numpy_vector_representation = preprocess_molecules(smiles_input)

    # Explicit check rather than ``assert`` so validation still runs under
    # ``python -O``.
    if numpy_vector_representation.ndim != 2:
        raise ValueError(
            "Expected a 2-D descriptor array from preprocess_molecules, got "
            f"shape {numpy_vector_representation.shape}"
        )

    # Create pytorch tensor: (n_rows, n_features) -> (n_rows, 1, n_features)
    return torch.from_numpy(numpy_vector_representation).unsqueeze(1).float()
def create_support_set_input(
    smiles_input: Union[str, List[str], pd.DataFrame],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create the input tensors for the support set molecules.

    Args:
        smiles_input: A SMILES string, a list of SMILES strings, or a pandas
            DataFrame. It is passed unchanged to ``preprocess_molecules``.

    Returns:
        A tuple ``(tensor, size)``:

        - ``tensor``: float32 tensor of shape ``(1, n_rows, n_features)``.
          The whole support set is stacked under a single leading axis
          inserted at position 0.
        - ``size``: 0-dim integer tensor holding ``n_rows``, the number of
          rows in the descriptor array (presumably the number of support
          molecules — confirm against ``preprocess_molecules``).

    Raises:
        ValueError: If ``preprocess_molecules`` does not return a 2-D array.
    """
    # Create vector representation
    numpy_vector_representation = preprocess_molecules(smiles_input)

    # Explicit check rather than ``assert`` so validation still runs under
    # ``python -O``.
    if numpy_vector_representation.ndim != 2:
        raise ValueError(
            "Expected a 2-D descriptor array from preprocess_molecules, got "
            f"shape {numpy_vector_representation.shape}"
        )
    n_rows = numpy_vector_representation.shape[0]

    # Create pytorch tensors: (n_rows, n_features) -> (1, n_rows, n_features)
    tensor = torch.from_numpy(numpy_vector_representation).unsqueeze(0).float()
    size = torch.tensor(n_rows)
    return tensor, size