Source code for datasets.omdb

import logging
import os
import tarfile
from typing import List, Optional, Dict
from ase.io import read

import numpy as np

import torch
from schnetpack.data import *
from schnetpack.data import AtomsDataModuleError, AtomsDataModule


__all__ = ["OrganicMaterialsDatabase"]


[docs]class OrganicMaterialsDatabase(AtomsDataModule): """ Organic Materials Database (OMDB) of bulk organic crystals. Registration to the OMDB is free for academic users. This database contains DFT (PBE) band gap (OMDB-GAP1 database) for 12500 non-magnetic materials. References: .. [#omdb] Bart Olsthoorn, R. Matthias Geilhufe, Stanislav S. Borysov, Alexander V. Balatsky. Band gap prediction for large organic crystal structures with machine learning. https://arxiv.org/abs/1810.12814 """ BandGap = "band_gap" def __init__( self, datapath: str, batch_size: int, num_train: Optional[int] = None, num_val: Optional[int] = None, num_test: Optional[int] = None, split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, transforms: Optional[List[torch.nn.Module]] = None, train_transforms: Optional[List[torch.nn.Module]] = None, val_transforms: Optional[List[torch.nn.Module]] = None, test_transforms: Optional[List[torch.nn.Module]] = None, num_workers: int = 2, num_val_workers: Optional[int] = None, num_test_workers: Optional[int] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, raw_path: Optional[str] = None, **kwargs ): """ Args: datapath: path to dataset batch_size: (train) batch size num_train: number of training examples num_val: number of validation examples num_test: number of test examples split_file: path to npz file with data partitions format: dataset format load_properties: subset of properties to load val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. test_batch_size: test batch size. If None, use val_batch_size, then batch_size. transforms: Transform applied to each system separately before batching. train_transforms: Overrides transform_fn for training. val_transforms: Overrides transform_fn for validation. test_transforms: Overrides transform_fn for testing. num_workers: Number of data loader workers. num_val_workers: Number of validation data loader workers (overrides num_workers). num_test_workers: Number of test data loader workers (overrides num_workers). property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). raw_path: path to raw tar.gz file with the data """ super().__init__( datapath=datapath, batch_size=batch_size, num_train=num_train, num_val=num_val, num_test=num_test, split_file=split_file, format=format, load_properties=load_properties, val_batch_size=val_batch_size, test_batch_size=test_batch_size, transforms=transforms, train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, num_workers=num_workers, num_val_workers=num_val_workers, num_test_workers=num_test_workers, property_units=property_units, distance_unit=distance_unit, **kwargs ) self.raw_path = raw_path def prepare_data(self): if not os.path.exists(self.datapath): property_unit_dict = {OrganicMaterialsDatabase.BandGap: "eV"} dataset = create_dataset( datapath=self.datapath, format=self.format, distance_unit="Ang", property_unit_dict=property_unit_dict, ) self._convert(dataset) else: dataset = load_dataset(self.datapath, self.format) def _convert(self, dataset): """ Converts .tar.gz to a .db file """ if self.raw_path is None or not os.path.exists(self.raw_path): # TODO: can we download here automatically like QM9? raise AtomsDataModuleError( "The path to the raw dataset is not provided or invalid and the db-file does " "not exist!" ) logging.info("Converting %s to a .db file.." % self.raw_path) tar = tarfile.open(self.raw_path, "r:gz") names = tar.getnames() tar.extractall() tar.close() structures = read("structures.xyz", index=":") Y = np.loadtxt("bandgaps.csv") [os.remove(name) for name in names] atoms_list = [] property_list = [] for i, at in enumerate(structures): atoms_list.append(at) property_list.append({OrganicMaterialsDatabase.BandGap: np.array([Y[i]])}) dataset.add_systems(atoms_list=atoms_list, property_list=property_list)