""" In this file, the input functions for query and support set molecules are defined. Input is assumed to be either a SMILES string, a list of SMILES strings, or a pandas dataframe. """ #--------------------------------------------------------------------------------------- # Dependencies import pandas as pd from typing import List import torch from src.data_preprocessing.create_descriptors import preprocess_molecules #--------------------------------------------------------------------------------------- # Define main functions def create_query_input(smiles_input: [str, List[str], pd.DataFrame]): """ This function creates the input for the query molecules. """ # Create vector representation numpy_vector_representation = preprocess_molecules(smiles_input) assert len(numpy_vector_representation.shape) == 2 # Create pytorch tensor tensor = torch.from_numpy(numpy_vector_representation).unsqueeze(1).float() return tensor def create_support_set_input(smiles_input: [str, List[str], pd.DataFrame]): """ This function creates the input for the support set molecules. """ # Create vector representation numpy_vector_representation = preprocess_molecules(smiles_input) assert len(numpy_vector_representation.shape) == 2 size = numpy_vector_representation.shape[0] # Create pytorch tensors tensor = torch.from_numpy(numpy_vector_representation).unsqueeze(0).float() size = torch.tensor(size) return tensor, size