File size: 3,117 Bytes
592e96e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import pandas as pd
from tqdm import tqdm
from rdkit import Chem, RDLogger
from datasets import load_dataset
from multiprocessing import Pool, cpu_count
import os

# Suppress RDKit console output for cleaner logs
RDLogger.DisableLog('rdApp.*')

class SmilesEnumerator:
    """
    Thin, stateless wrapper around RDKit SMILES randomization.

    ``create_augmented_pair`` instantiates it inside each worker process; it
    holds no state, so creating an instance per call is cheap.
    """
    def randomize_smiles(self, smiles):
        """
        Return a randomized, non-canonical SMILES string for ``smiles``.

        Args:
            smiles: Input SMILES string.

        Returns:
            A SMILES string produced by ``Chem.MolToSmiles(..., doRandom=True,
            canonical=False)``, or ``smiles`` unchanged when RDKit cannot parse
            it (``MolFromSmiles`` returns ``None``) or raises an exception.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # Unparseable input: keep the original string as a fallback.
                return smiles
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
        except Exception:
            # Best-effort fallback for any RDKit error. Using ``Exception``
            # (not a bare ``except:``) lets KeyboardInterrupt/SystemExit
            # propagate so a pool worker can still be interrupted.
            return smiles

def create_augmented_pair(smiles_string):
    """
    Pool worker: randomize a single SMILES string twice.

    Each output comes from a separate call to
    ``SmilesEnumerator.randomize_smiles``. Nothing here checks that the two
    results differ, so they may coincide.

    Returns:
        A ``(smiles_1, smiles_2)`` tuple of strings.
    """
    randomize = SmilesEnumerator().randomize_smiles
    first_view = randomize(smiles_string)
    second_view = randomize(smiles_string)
    return (first_view, second_view)

def main(
    dataset_name='jablonkagroup/pubchem-smiles-molecular-formula',
    smiles_column_name='smiles',
    output_path='data/pubchem_computed_110_end_M.parquet',
    start_index=110_000_000,
    num_workers=None,
    chunksize=1000,
):
    """
    Run the parallel SMILES-augmentation preprocessing and save a Parquet file.

    Loads the ``train`` split of ``dataset_name``, keeps rows from
    ``start_index`` to the end of the split, builds two randomized SMILES per
    row with ``create_augmented_pair`` in a process pool, and writes them as
    columns ``smiles_1`` / ``smiles_2`` to ``output_path``.

    Args:
        dataset_name: Hugging Face dataset identifier.
        smiles_column_name: Column containing the SMILES strings.
        output_path: Destination Parquet file; its parent directory is created
            if needed.
        start_index: First (0-based) row to process. Values past the end of
            the split produce an empty selection.
        num_workers: Number of worker processes; ``None`` means ``cpu_count()``.
        chunksize: Number of SMILES handed to a worker per task; batching
            reduces inter-process overhead on very large inputs.
    """
    # --- Data Loading ---
    print(f"Loading dataset '{dataset_name}'...")
    # The whole split is loaded (no streaming); only its tail is kept below.
    dataset = load_dataset(dataset_name, split='train')
    total_rows = len(dataset)
    # ``range(110_000_000, )`` means ``range(0, 110_000_000)`` (a lone argument
    # is the *stop*), which selected the first rows rather than start..end.
    # Clamp so an out-of-range start gives an empty selection.
    first_row = max(0, min(start_index, total_rows))
    dataset = dataset.select(range(first_row, total_rows))

    smiles_list = dataset[smiles_column_name]
    print(f"Successfully fetched {len(smiles_list)} SMILES strings.")

    # --- Parallel Processing ---
    if num_workers is None:
        num_workers = cpu_count()
    print(f"Starting SMILES augmentation with {num_workers} worker processes...")

    with Pool(num_workers) as p:
        # imap preserves input order; an explicit chunksize avoids one IPC
        # round-trip per SMILES (the default chunksize is 1).
        results = list(tqdm(
            p.imap(create_augmented_pair, smiles_list, chunksize=chunksize),
            total=len(smiles_list),
            desc="Augmenting Pairs",
        ))

    # --- Saving Data ---
    print("Processing complete. Converting to DataFrame...")
    df = pd.DataFrame(results, columns=['smiles_1', 'smiles_2'])

    # os.path.dirname returns '' for a bare filename, and os.makedirs('')
    # raises, so only create a directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving augmented pairs to '{output_path}'...")
    df.to_parquet(output_path)

    print("All done. Your pre-computed dataset is ready!")

if __name__ == '__main__':
    main()