File size: 2,723 Bytes
62a2f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import os.path as osp
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url
from tqdm import tqdm
class MD22(InMemoryDataset):
    """MD22 molecular-dynamics trajectories as an in-memory PyG dataset.

    A single molecule is selected through ``dataset_arg``. Its ``.npz`` file
    is downloaded from ``base_url`` into ``self.raw_dir`` and converted by
    ``process`` into one ``Data`` object per trajectory frame with the fields

    * ``z``   -- the ``z`` array of the file as int64 (shared by all frames;
      presumably atomic numbers -- confirm against the data files),
    * ``pos`` -- one entry of ``R`` (positions) as float32,
    * ``y``   -- one entry of ``E`` (energies) as float32 with an extra dim,
    * ``dy``  -- one entry of ``F`` (forces) as float32.

    Args:
        root: Parent directory; the dataset lives in ``<root>/<dataset_arg>``.
        dataset_arg: Molecule key, one of the keys of ``molecule_names``.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Applied to every sample once, inside ``process``.
        pre_filter: Optional predicate ``Data -> bool``; samples for which it
            returns a falsy value are dropped inside ``process``.

    Raises:
        TypeError: If ``dataset_arg`` is ``None``.
        KeyError: If ``dataset_arg`` is not a known molecule key.
    """

    # Molecule key -> raw file name below ``base_url``.
    _RAW_FILES = {
        "Ac_Ala3_NHMe": "md22_Ac-Ala3-NHMe.npz",
        "DHA": "md22_DHA.npz",
        "stachyose": "md22_stachyose.npz",
        "AT_AT": "md22_AT-AT.npz",
        "AT_AT_CG_CG": "md22_AT-AT-CG-CG.npz",
        "buckyball_catcher": "md22_buckyball-catcher.npz",
        "double_walled_nanotube": "md22_dw_nanotube.npz",
    }

    # Molecule key -> split size, see ``molecule_splits``.
    _SPLITS = {
        "Ac_Ala3_NHMe": 6000,
        "DHA": 8000,
        "stachyose": 8000,
        "AT_AT": 3000,
        "AT_AT_CG_CG": 2000,
        "buckyball_catcher": 600,
        "double_walled_nanotube": 800,
    }

    def __init__(self, root, dataset_arg=None, transform=None, pre_transform=None,
                 pre_filter=None):
        # Fail early with a readable message; without these checks a missing
        # argument surfaces as a TypeError from osp.join and an unknown key as
        # a bare KeyError from the file-name lookup (same exception types).
        if dataset_arg is None:
            raise TypeError(
                "MD22 requires 'dataset_arg'; available molecules: "
                + ", ".join(self._RAW_FILES)
            )
        if dataset_arg not in self._RAW_FILES:
            raise KeyError(
                f"Unknown MD22 molecule {dataset_arg!r}; available molecules: "
                + ", ".join(self._RAW_FILES)
            )

        # Must be set before the base constructor runs: the file-name
        # properties below read it.
        self.dataset_arg = dataset_arg
        super().__init__(osp.join(root, dataset_arg), transform, pre_transform, pre_filter)
        # Loads the (data, slices) tuple written by ``process``.
        # NOTE(review): torch.load unpickles this file; recent PyTorch releases
        # may default to weights_only=True and reject it -- confirm against the
        # torch version this project pins.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def molecule_names(self):
        """Return a fresh dict mapping molecule key -> raw ``.npz`` file name."""
        return dict(self._RAW_FILES)

    @property
    def raw_file_names(self):
        """Return the single raw file name of the selected molecule."""
        return [self._RAW_FILES[self.dataset_arg]]

    @property
    def processed_file_names(self):
        """Return the single processed file name of the selected molecule."""
        return [f"md22_{self.dataset_arg}.pt"]

    @property
    def base_url(self):
        """Return the URL prefix the raw ``.npz`` files are fetched from."""
        return "http://www.quantum-machine.org/gdml/data/npz/"

    def download(self):
        """Download the selected molecule's raw file into ``self.raw_dir``."""
        download_url(self.base_url + self._RAW_FILES[self.dataset_arg], self.raw_dir)

    def process(self):
        """Convert each raw ``.npz`` file into a collated ``(data, slices)`` file."""
        for raw_path, processed_path in zip(self.raw_paths, self.processed_paths):
            # Close the archive as soon as the arrays have been read.
            with np.load(raw_path) as data_npz:
                z = torch.from_numpy(data_npz["z"]).long()
                positions = torch.from_numpy(data_npz["R"]).float()
                energies = torch.from_numpy(data_npz["E"]).float()
                forces = torch.from_numpy(data_npz["F"]).float()

            samples = []
            frames = zip(positions, energies, forces)
            for pos, y, dy in tqdm(frames, total=energies.size(0)):
                # NOTE(review): unsqueeze(1) requires each energy entry to be
                # at least 1-D, i.e. "E" stored as (n_frames, 1) -- confirm.
                data = Data(z=z, pos=pos, y=y.unsqueeze(1), dy=dy)

                # pre_filter is a predicate: it decides whether the sample is
                # kept, it does not return a replacement sample.
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                samples.append(data)

            data, slices = self.collate(samples)
            torch.save((data, slices), processed_path)

    @property
    def molecule_splits(self):
        """Return a fresh dict mapping molecule key -> split size.

        The sizes refer to the MD22 paper, https://arxiv.org/pdf/2209.14865.pdf
        (presumably the number of training samples per molecule -- confirm
        against the paper).
        """
        return dict(self._SPLITS)