mhnfs / src /data_preprocessing /create_model_inputs.py
Tschoui's picture
move project from private to public space
cf004a6
raw
history blame contribute delete
No virus
1.53 kB
"""
In this file, the input functions for query and support set molecules are defined.
Input is assumed to be either a SMILES string, a list of SMILES strings, or a pandas
dataframe.
"""
#---------------------------------------------------------------------------------------
# Dependencies
from typing import List, Tuple, Union

import pandas as pd
import torch

from src.data_preprocessing.create_descriptors import preprocess_molecules
#---------------------------------------------------------------------------------------
# Define main functions
def create_query_input(
    smiles_input: Union[str, List[str], pd.DataFrame],
) -> torch.Tensor:
    """
    Create the model input tensor for the query molecules.

    Args:
        smiles_input: A SMILES string, a list of SMILES strings, or a pandas
            dataframe. It is passed unchanged to ``preprocess_molecules``.

    Returns:
        A float32 tensor with a singleton axis inserted at position 1, i.e.
        shape ``(rows, 1, columns)`` for a descriptor array of shape
        ``(rows, columns)`` (presumably one row per molecule; confirm against
        ``preprocess_molecules``).

    Raises:
        AssertionError: If ``preprocess_molecules`` does not return a
            2-dimensional array.
    """
    # Create vector representation
    descriptors = preprocess_molecules(smiles_input)

    # Invariant on the helper's output: the unsqueeze below relies on a
    # 2-dimensional array.
    assert descriptors.ndim == 2, (
        "preprocess_molecules must return a 2-dimensional array, "
        f"got shape {descriptors.shape}"
    )

    # Create pytorch tensor
    return torch.from_numpy(descriptors).unsqueeze(1).float()
def create_support_set_input(
    smiles_input: Union[str, List[str], pd.DataFrame],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create the model input tensors for the support set molecules.

    Args:
        smiles_input: A SMILES string, a list of SMILES strings, or a pandas
            dataframe. It is passed unchanged to ``preprocess_molecules``.

    Returns:
        A tuple ``(tensor, size)``:

        * ``tensor`` is a float32 tensor with a singleton axis inserted at
          position 0, i.e. shape ``(1, rows, columns)`` for a descriptor
          array of shape ``(rows, columns)``.
        * ``size`` is a 0-dimensional tensor holding ``rows``, the length of
          the first axis of the descriptor array (presumably the number of
          support set molecules; confirm against ``preprocess_molecules``).

    Raises:
        AssertionError: If ``preprocess_molecules`` does not return a
            2-dimensional array.
    """
    # Create vector representation
    descriptors = preprocess_molecules(smiles_input)

    # Invariant on the helper's output: both the row count and the unsqueeze
    # below rely on a 2-dimensional array.
    assert descriptors.ndim == 2, (
        "preprocess_molecules must return a 2-dimensional array, "
        f"got shape {descriptors.shape}"
    )
    n_rows = descriptors.shape[0]

    # Create pytorch tensors
    support_tensor = torch.from_numpy(descriptors).unsqueeze(0).float()
    support_size = torch.tensor(n_rows)
    return support_tensor, support_size