File size: 2,723 Bytes
62a2f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os.path as osp

import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url
from tqdm import tqdm


class MD22(InMemoryDataset):
    """MD22 molecular-dynamics dataset loaded from the sGDML ``.npz`` files.

    One molecule (selected by ``dataset_arg``) is downloaded from
    :attr:`base_url` and processed into a collated ``InMemoryDataset``. Each
    frame becomes a :class:`Data` with fields ``z`` (atomic numbers, from
    ``"z"``), ``pos`` (from ``"R"``), ``y`` (energy from ``"E"``, reshaped to
    ``(-1, 1)``) and ``dy`` (forces, from ``"F"``).

    Args:
        root: Base directory; files live under ``root/<dataset_arg>``.
        dataset_arg: Molecule key, one of the keys of :attr:`molecule_names`.
        transform: Optional transform applied by PyG on sample access.
        pre_transform: Optional transform applied once before saving.
        pre_filter: Optional predicate; frames for which it returns a falsy
            value are dropped before saving.

    Raises:
        ValueError: If ``dataset_arg`` is not a known molecule key.
    """

    # Molecule key -> raw file name on the quantum-machine.org server.
    _MOLECULE_FILES = dict(
        Ac_Ala3_NHMe="md22_Ac-Ala3-NHMe.npz",
        DHA="md22_DHA.npz",
        stachyose="md22_stachyose.npz",
        AT_AT="md22_AT-AT.npz",
        AT_AT_CG_CG="md22_AT-AT-CG-CG.npz",
        buckyball_catcher="md22_buckyball-catcher.npz",
        double_walled_nanotube="md22_dw_nanotube.npz",
    )

    # Molecule key -> split size (see molecule_splits).
    _MOLECULE_SPLITS = dict(
        Ac_Ala3_NHMe=6000,
        DHA=8000,
        stachyose=8000,
        AT_AT=3000,
        AT_AT_CG_CG=2000,
        buckyball_catcher=600,
        double_walled_nanotube=800,
    )

    def __init__(self, root, dataset_arg=None, transform=None, pre_transform=None,
                 pre_filter=None):
        # Validate before osp.join/raw_file_names so a missing or misspelled
        # molecule yields a clear error instead of TypeError/KeyError.
        if dataset_arg not in self._MOLECULE_FILES:
            raise ValueError(
                f"Unknown MD22 molecule {dataset_arg!r}; expected one of "
                f"{sorted(self._MOLECULE_FILES)}"
            )

        self.dataset_arg = dataset_arg

        super().__init__(osp.join(root, dataset_arg), transform, pre_transform, pre_filter)

        # NOTE(review): the processed file is a pickled (Data, slices) tuple
        # written by process(). Newer torch releases default torch.load to
        # weights_only=True, which may reject it depending on the installed
        # torch / torch_geometric versions — confirm in the target environment.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def molecule_names(self):
        """Return a fresh mapping of molecule key -> raw ``.npz`` file name."""
        return dict(self._MOLECULE_FILES)

    @property
    def raw_file_names(self):
        return [self.molecule_names[self.dataset_arg]]

    @property
    def processed_file_names(self):
        return [f"md22_{self.dataset_arg}.pt"]

    @property
    def base_url(self):
        return "http://www.quantum-machine.org/gdml/data/npz/"

    def download(self):
        """Download the raw ``.npz`` file for the selected molecule into ``raw_dir``."""
        download_url(self.base_url + self.molecule_names[self.dataset_arg], self.raw_dir)

    def process(self):
        """Convert each raw ``.npz`` file into a collated ``(data, slices)`` file."""
        for path, processed_path in zip(self.raw_paths, self.processed_paths):
            # NpzFile holds an open zip handle; the context manager closes it.
            # Arrays read from it are materialized, so the tensors stay valid.
            with np.load(path) as data_npz:
                z = torch.from_numpy(data_npz["z"]).long()
                positions = torch.from_numpy(data_npz["R"]).float()
                energies = torch.from_numpy(data_npz["E"]).float()
                forces = torch.from_numpy(data_npz["F"]).float()

            samples = []
            for pos, y, dy in tqdm(zip(positions, energies, forces), total=energies.size(0)):
                # reshape(-1, 1) matches unsqueeze(1) for 1-D per-frame energies
                # and additionally handles 0-d (scalar) energies.
                data = Data(z=z, pos=pos, y=y.reshape(-1, 1), dy=dy)

                # pre_filter is a predicate in PyG: keep the frame only if it
                # returns a truthy value (it does not transform the sample).
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                samples.append(data)

            data, slices = self.collate(samples)
            torch.save((data, slices), processed_path)

    @property
    def molecule_splits(self):
        """Return a fresh mapping of molecule key -> split size.

        Splits refer to MD22 https://arxiv.org/pdf/2209.14865.pdf
        (presumably the per-molecule training-set sizes — confirm against the
        caller that consumes them).
        """
        return dict(self._MOLECULE_SPLITS)