# Source code for datasets.md22

import torch
from typing import Optional, Dict, List

from schnetpack.data import *
from schnetpack.datasets.md17 import GDMLDataModule


# Public API of this module. The name must be the dunder ``__all__`` for
# ``from <module> import *`` to honor it; a plain ``all`` would not restrict
# the star-import and would rebind the builtin ``all()`` to a list.
__all__ = ["MD22"]


class MD22(GDMLDataModule):
    """
    MD22 benchmark data set for extended molecules containing molecular forces.

    This class only supplies the MD22-specific configuration (atom reference
    values, the molecule-name -> file-name table, the download URL and the
    temporary directory name) and forwards everything else unchanged to
    ``GDMLDataModule``.

    References:
        .. [#md22_1] http://quantum-machine.org/gdml/#datasets

    """

    def __init__(
        self,
        datapath: str,
        molecule: str,
        batch_size: int,
        num_train: Optional[int] = None,
        num_val: Optional[int] = None,
        num_test: Optional[int] = None,
        split_file: Optional[str] = "split.npz",
        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
        load_properties: Optional[List[str]] = None,
        val_batch_size: Optional[int] = None,
        test_batch_size: Optional[int] = None,
        transforms: Optional[List[torch.nn.Module]] = None,
        train_transforms: Optional[List[torch.nn.Module]] = None,
        val_transforms: Optional[List[torch.nn.Module]] = None,
        test_transforms: Optional[List[torch.nn.Module]] = None,
        num_workers: int = 2,
        num_val_workers: Optional[int] = None,
        num_test_workers: Optional[int] = None,
        property_units: Optional[Dict[str, str]] = None,
        distance_unit: Optional[str] = None,
        data_workdir: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            datapath: path to dataset
            molecule: name of the molecule to load. The names defined by this
                class are the keys of ``datasets_dict`` below (e.g. "DHA",
                "AT-AT"); any validation of the name is left to
                ``GDMLDataModule``.
            batch_size: (train) batch size
            num_train: number of training examples
            num_val: number of validation examples
            num_test: number of test examples
            split_file: path to npz file with data partitions
            format: dataset format
            load_properties: subset of properties to load
            val_batch_size: validation batch size. If None, use
                test_batch_size, then batch_size.
            test_batch_size: test batch size. If None, use val_batch_size,
                then batch_size.
            transforms: Transform applied to each system separately before
                batching.
            train_transforms: Overrides ``transforms`` for training.
            val_transforms: Overrides ``transforms`` for validation.
            test_transforms: Overrides ``transforms`` for testing.
            num_workers: Number of data loader workers.
            num_val_workers: Number of validation data loader workers
                (overrides num_workers).
            num_test_workers: Number of test data loader workers (overrides
                num_workers).
            property_units: Dictionary from property name to a unit string;
                passed through to ``GDMLDataModule`` unchanged.
            distance_unit: Unit of the atom positions and cell as a string
                (Ang, Bohr, ...).
            data_workdir: Copy data here as part of setup, e.g. cluster
                scratch for faster performance.
            **kwargs: any further keyword arguments are forwarded to
                ``GDMLDataModule.__init__``.
        """
        # Per-element reference energies, keyed by the energy property name.
        # ``self.energy`` is read before ``super().__init__`` runs, so it has
        # to be available as a class-level attribute of the base class.
        # NOTE(review): the list is presumably indexed by atomic number
        # (non-zero entries at 1, 6, 7, 8 -> H, C, N, O) -- confirm.
        atomrefs = {
            self.energy: [
                0.0,
                -313.5150902000774,
                0.0,
                0.0,
                0.0,
                0.0,
                -23622.587180094913,
                -34219.46811826416,
                -47069.30768969713,
            ]
        }

        # Molecule name (as passed via ``molecule``) -> file name of the
        # corresponding archive, resolved relative to ``download_url``.
        datasets_dict = {
            "Ac-Ala3-NHMe": "md22_Ac-Ala3-NHMe.npz",
            "DHA": "md22_DHA.npz",
            "stachyose": "md22_stachyose.npz",
            "AT-AT": "md22_AT-AT.npz",
            "AT-AT-CG-CG": "md22_AT-AT-CG-CG.npz",
            "buckyball-catcher": "md22_buckyball-catcher.npz",
            "double-walled_nanotube": "md22_double-walled_nanotube.npz",
        }

        super().__init__(
            datasets_dict=datasets_dict,
            download_url="http://www.quantum-machine.org/gdml/repo/datasets/",
            tmpdir="md22",
            molecule=molecule,
            datapath=datapath,
            batch_size=batch_size,
            num_train=num_train,
            num_val=num_val,
            num_test=num_test,
            split_file=split_file,
            format=format,
            load_properties=load_properties,
            val_batch_size=val_batch_size,
            test_batch_size=test_batch_size,
            transforms=transforms,
            train_transforms=train_transforms,
            val_transforms=val_transforms,
            test_transforms=test_transforms,
            num_workers=num_workers,
            num_val_workers=num_val_workers,
            num_test_workers=num_test_workers,
            property_units=property_units,
            distance_unit=distance_unit,
            data_workdir=data_workdir,
            atomrefs=atomrefs,
            **kwargs,
        )