Source code for datasets.tmqm

import io
import logging
import os
import re
import shutil
import tarfile
import tempfile
from typing import List, Optional, Dict
from urllib import request as request
import gzip

import numpy as np
from ase import Atoms
from ase.io.extxyz import read_xyz
from ase.io import read
from tqdm import tqdm

import torch
from schnetpack.data import *
import schnetpack.properties as structure
from schnetpack.data import AtomsDataModuleError, AtomsDataModule

__all__ = ["TMQM"]


[docs]class TMQM(AtomsDataModule): """tmQM database of Ballcells 2020 of inorganic CSD structures. References: .. [#tmqm_1] https://pubs.acs.org/doi/10.1021/acs.jcim.0c01041 """ # properties # electronic and dispersion energies, # highest occupied molecular orbital (HOMO) and lowest unoccupied molecular orbital (LUMO) energies, # HOMO/LUMO gap, # dipole moment, and natural charge of the metal center; GFN2-xTB polarizabilities are also provided. # these strings match the names in the header of the csv file csd_code = "CSD_code" # should go into key-value pair energy = "Electronic_E" dispersion = "Dispersion_E" homo = "HOMO_Energy" lumo = "LUMO_Energy" gap = "HL_Gap" mu = "Dipole_M" alpha = "Polarizability" qm = "Metal_q" def __init__( self, datapath: str, batch_size: int, num_train: Optional[int] = None, num_val: Optional[int] = None, num_test: Optional[int] = None, split_file: Optional[str] = "split.npz", format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, load_properties: Optional[List[str]] = None, val_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, transforms: Optional[List[torch.nn.Module]] = None, train_transforms: Optional[List[torch.nn.Module]] = None, val_transforms: Optional[List[torch.nn.Module]] = None, test_transforms: Optional[List[torch.nn.Module]] = None, num_workers: int = 2, num_val_workers: Optional[int] = None, num_test_workers: Optional[int] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, data_workdir: Optional[str] = None, **kwargs, ): """ Args: datapath: path to dataset batch_size: (train) batch size num_train: number of training examples num_val: number of validation examples num_test: number of test examples split_file: path to npz file with data partitions format: dataset format load_properties: subset of properties to load remove_uncharacterized: do not include uncharacterized molecules. val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. test_batch_size: test batch size. If None, use val_batch_size, then batch_size. transforms: Transform applied to each system separately before batching. train_transforms: Overrides transform_fn for training. val_transforms: Overrides transform_fn for validation. test_transforms: Overrides transform_fn for testing. num_workers: Number of data loader workers. num_val_workers: Number of validation data loader workers (overrides num_workers). num_test_workers: Number of test data loader workers (overrides num_workers). property_units: Dictionary from property to corresponding unit as a string (eV, kcal/mol, ...). distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. """ super().__init__( datapath=datapath, batch_size=batch_size, num_train=num_train, num_val=num_val, num_test=num_test, split_file=split_file, format=format, load_properties=load_properties, val_batch_size=val_batch_size, test_batch_size=test_batch_size, transforms=transforms, train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, num_workers=num_workers, num_val_workers=num_val_workers, num_test_workers=num_test_workers, property_units=property_units, distance_unit=distance_unit, data_workdir=data_workdir, **kwargs, ) def prepare_data(self): if not os.path.exists(self.datapath): property_unit_dict = { TMQM.energy: "Ha", TMQM.dispersion: "Ha", TMQM.homo: "Ha", TMQM.lumo: "Ha", TMQM.gap: "Ha", TMQM.mu: "Debye", TMQM.alpha: "a0 a0 a0", TMQM.qm: "e", } tmpdir = tempfile.mkdtemp("tmQM") dataset = create_dataset( datapath=self.datapath, format=self.format, distance_unit="Ang", property_unit_dict=property_unit_dict, ) self._download_data(tmpdir, dataset) shutil.rmtree(tmpdir) else: dataset = load_dataset(self.datapath, self.format) def _download_data(self, tmpdir, dataset: BaseAtomsData): tar_path = os.path.join(tmpdir, "tmQM_X1.xyz.gz") url = [ "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X1.xyz.gz", "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X2.xyz.gz", ] url_y = "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_y.csv" tmp_xyz_file = os.path.join(tmpdir, "tmQM_X.xyz") tmp_properties_file = os.path.join(tmpdir, "tmQM_y.csv") atomslist = [] for u in url: request.urlretrieve(u, tar_path) with gzip.open(tar_path, "rb") as f_in: with open(tmp_xyz_file, "wb") as f_out: lines = f_in.readlines() # remove empty lines lines = [line for line in lines if line.strip()] f_out.writelines(lines) atomslist.extend(read(tmp_xyz_file, index=":")) # download proeprties in tmQM_y.csv request.urlretrieve(url_y, tmp_properties_file) # CSV format # CSD_code;Electronic_E;Dispersion_E;Dipole_M;Metal_q;HL_Gap;HOMO_Energy;LUMO_Energy;Polarizability # WIXKOE;-2045.524942;-0.239239;4.233300;2.109340;0.131080;-0.162040;-0.030960;598.457913 # DUCVIG;-2430.690317;-0.082134;11.754400;0.759940;0.124930;-0.243580;-0.118650;277.750698 # KINJOG;-3467.923206;-0.137954;8.301700;1.766500;0.140140;-0.236460;-0.096320;393.442545 # read csv prop_list = [] key_value_pairs_list = [] with open(tmp_properties_file, "r") as file: lines = file.readlines() keys = lines[0].strip("\n").split(";") for l in lines[1:]: properties = l.split(";") prop_dict = { k: np.array([float(v)]) for k, v in zip(keys[1:], properties[1:]) } key_value_pairs = {k: v for k, v in zip(keys[0], properties[0])} prop_list.append(prop_dict) key_value_pairs_list.append(key_value_pairs) dataset.add_systems(property_list=prop_list, atoms_list=atomslist)