# Source code for datasets.qm7x

import hashlib
import logging
import lzma
import os
import re
import shutil
import tempfile
from typing import Dict, List, Optional
from urllib import request as request

import h5py
import numpy as np
import progressbar
import torch
from ase import Atoms
from tqdm import tqdm

from schnetpack.data import *
from schnetpack.data import AtomsDataModule
from schnetpack.data.splitting import GroupSplit

# Names exported by ``from <this module> import *``.
__all__ = ["QM7X"]

# Helper functions

# Progress bar of the download currently in flight; ``None`` while no bar is
# active. It is module-level state because the download hook is a plain callback.
pbar = None


def show_progress(block_num: int, block_size: int, total_size: int):
    """
    Progress callback for file downloads (``urlretrieve``-style report hook).

    Args:
        block_num: number of blocks transferred so far
        block_size: size of one block in bytes
        total_size: total size of the file in bytes; values <= 0 mean the size
            is unknown (or the file is empty) and no progress bar is shown
    """
    global pbar

    # urlretrieve reports -1 when the server does not announce a size. A bar
    # with a non-positive maximum cannot be driven meaningfully, so skip it.
    if total_size <= 0:
        return

    if pbar is None:
        pbar = progressbar.ProgressBar(maxval=total_size)
        pbar.start()

    downloaded = block_num * block_size
    if downloaded < total_size:
        pbar.update(downloaded)
    else:
        # last block reached: close the bar and reset for the next download
        pbar.finish()
        pbar = None


def _md5_of_file(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex md5 digest of the file at ``path``, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        # chunked reads keep memory flat even for multi-gigabyte archives
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_and_check(url: str, tar_path: str, checksum: str):
    """
    Download file from url to tar_path and check md5 checksum.

    If ``tar_path`` already exists and matches ``checksum`` nothing is
    downloaded; if it exists with a different checksum it is deleted and
    fetched again.

    Args:
        url: address of the file to download
        tar_path: local path the file is stored at
        checksum: expected md5 hex digest of the file

    Raises:
        RuntimeError: if the freshly downloaded file does not match ``checksum``.
    """

    file = url.split("/")[-1]

    # check if file already exists and has correct checksum
    if os.path.exists(tar_path):
        if _md5_of_file(tar_path) == checksum:
            logging.info(
                f"File {file} already exists and has correct checksum. Skipping download."
            )
            return

        logging.info(
            f"File {file} already exists but has wrong checksum. Redownloading."
        )
        os.remove(tar_path)

    logging.info(f"Downloading {url} ...")
    request.urlretrieve(url, tar_path, show_progress)

    if _md5_of_file(tar_path) != checksum:
        raise RuntimeError(
            f"Checksum of downloaded file {file} does not match. Please try again."
        )
    logging.info("Done.")


def extract_xz(source: str, target: str):
    """
    Helper to extract xz files.

    The archive is decompressed into ``<target>.part`` and only renamed to
    ``target`` once extraction has finished, so an interrupted run never leaves
    a truncated ``target`` that a later call would skip as already extracted.

    Args:
        source: path of the xz-compressed file
        target: path the decompressed content is written to; if it already
            exists, extraction is skipped

    Raises:
        RuntimeError: if the file could not be extracted.
    """
    s_file = os.path.basename(source)
    t_file = os.path.basename(target)

    if os.path.exists(target):
        logging.info(f"File {t_file} already exists. Skipping extraction.")
        return

    logging.info(f"Extracting {s_file} ...")

    partial = target + ".part"
    try:
        with lzma.open(source) as fin, open(partial, mode="wb") as fout:
            shutil.copyfileobj(fin, fout)
        os.replace(partial, target)
    except BaseException as err:
        # always drop the incomplete output, also on Ctrl-C / interpreter exit
        if os.path.exists(partial):
            os.remove(partial)
        if isinstance(err, Exception):
            raise RuntimeError(
                f"Could not extract file {s_file}. Please try again."
            ) from err
        # KeyboardInterrupt / SystemExit must propagate unchanged
        raise

    logging.info("Done.")


class QM7X(AtomsDataModule):
    """
    QM7-X a comprehensive dataset of > 40 physicochemical properties for ~4.2 M
    equilibrium and non-equilibrium structure of small organic molecules with up
    to seven non-hydrogen (C, N, O, S, Cl) atoms.

    This class adds convenient functions to download QM7-X and load the data into
    pytorch.

    References:

        .. [#qm7x_1] https://zenodo.org/record/4288677

    """

    # More molecular and atomic properties can be found in the original paper and
    # added here. Notice that adding more properties can drastically increase the
    # size of the dataset. Adding more properties here requires to add them to
    # `property_unit_dict` and their key mapping in the raw dataset to
    # `property_dataset_keys`.
    forces = "forces"  # total ePBE0+MBD forces
    # ePBE0+MBD: total energy after convergence of the PBE0 exchange-correlation
    # functional and the MBD dispersion correction
    energy = "energy"
    # atomization energy using PBE0 energy per atom and ePBE0+MBD total energy
    Eat = "Eat"
    EPBE0 = "EPBE0"  # ePBE0: total energy at the level of PBE0
    EMBD = "EMBD"  # eMBD: total energy at the level of MBD
    # FPBE0: total ePBE0 forces. The name must differ from FMBD, otherwise both
    # properties share one key in the dictionaries below and one is lost.
    FPBE0 = "FPBE0"
    FMBD = "FMBD"  # FMBD: total eMBD forces
    # root mean square deviation of the atomic positions from the equilibrium
    # structure
    RMSD = "rmsd"

    property_unit_dict = {
        forces: "eV/Ang",
        energy: "eV",
        Eat: "eV",
        EPBE0: "eV",
        EMBD: "eV",
        FPBE0: "eV/Ang",
        FMBD: "eV/Ang",
        RMSD: "Ang",
    }

    # the original keys in the raw dataset to query the properties
    property_dataset_keys = {
        forces: "totFOR",
        energy: "ePBE0+MBD",
        Eat: "eAT",
        EPBE0: "ePBE0",
        EMBD: "eMBD",
        FPBE0: "pbe0FOR",
        FMBD: "vdwFOR",
        RMSD: "sRMSD",
    }

    # atom energies (atomrefs) from PBE0, keyed by atomic number
    EPBE0_atom = {
        1: -13.641404161,
        6: -1027.592489146,
        7: -1484.274819088,
        8: -2039.734879322,
        16: -10828.707468187,
        17: -12516.444619523,
    }

    def __init__(
        self,
        datapath: str,
        batch_size: int,
        raw_data_path: Optional[str] = None,
        remove_duplicates: bool = True,
        only_equilibrium: bool = False,
        only_non_equilibrium: bool = False,
        num_train: Optional[int] = None,
        num_val: Optional[int] = None,
        num_test: Optional[int] = None,
        split_file: Optional[str] = "split.npz",
        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
        load_properties: Optional[List[str]] = None,
        val_batch_size: Optional[int] = None,
        test_batch_size: Optional[int] = None,
        transforms: Optional[List[torch.nn.Module]] = None,
        train_transforms: Optional[List[torch.nn.Module]] = None,
        val_transforms: Optional[List[torch.nn.Module]] = None,
        test_transforms: Optional[List[torch.nn.Module]] = None,
        num_workers: int = 2,
        num_val_workers: Optional[int] = None,
        num_test_workers: Optional[int] = None,
        property_units: Optional[Dict[str, str]] = None,
        distance_unit: Optional[str] = None,
        data_workdir: Optional[str] = None,
        splitting: Optional[SplittingStrategy] = None,
        **kwargs,
    ):
        """
        Args:
            datapath: path to dataset
            batch_size: (train) batch size
            raw_data_path: path to raw data. If None use tmp dir otherwise
                persist data and not remove it.
            remove_duplicates: remove duplicated equilibrium structures with
                different non-equilibrium structures
            only_equilibrium: only use equilibrium structures
            only_non_equilibrium: only use non-equilibrium structures. Has no
                effect if `only_equilibrium` is set as well.
            num_train: number of training examples
            num_val: number of validation examples
            num_test: number of test examples
            split_file: path to npz file with data partitions
            format: dataset format
            load_properties: subset of properties to load
            val_batch_size: validation batch size. If None, use test_batch_size,
                then batch_size.
            test_batch_size: test batch size. If None, use val_batch_size, then
                batch_size.
            transforms: Transform applied to each system separately before
                batching.
            train_transforms: Overrides transform_fn for training.
            val_transforms: Overrides transform_fn for validation.
            test_transforms: Overrides transform_fn for testing.
            num_workers: Number of data loader workers.
            num_val_workers: Number of validation data loader workers
                (overrides num_workers).
            num_test_workers: Number of test data loader workers
                (overrides num_workers).
            property_units: Dictionary from property to corresponding unit as a
                string (eV, kcal/mol, ...).
            distance_unit: Unit of the atom positions and cell as a string
                (Ang, Bohr, ...).
            data_workdir: Copy data here as part of setup, e.g. cluster scratch
                for faster performance.
            splitting: Method to generate train/validation/test partitions
                (default: GroupSplit(splitting_key="smiles_id"))
        """
        super().__init__(
            datapath=datapath,
            batch_size=batch_size,
            num_train=num_train,
            num_val=num_val,
            num_test=num_test,
            split_file=split_file,
            format=format,
            load_properties=load_properties,
            val_batch_size=val_batch_size,
            test_batch_size=test_batch_size,
            transforms=transforms,
            train_transforms=train_transforms,
            val_transforms=val_transforms,
            test_transforms=test_transforms,
            num_workers=num_workers,
            num_val_workers=num_val_workers,
            num_test_workers=num_test_workers,
            property_units=property_units,
            distance_unit=distance_unit,
            data_workdir=data_workdir,
            splitting=splitting or GroupSplit(splitting_key="smiles_id"),
            **kwargs,
        )

        self.raw_data_path = raw_data_path
        self.remove_duplicates = remove_duplicates
        self.duplicates_ids = None
        self.only_equilibrium = only_equilibrium
        self.only_non_equilibrium = only_non_equilibrium

    def _download_duplicates_ids(self, tar_dir: str):
        """
        Download the list of duplicate molecules of QM7-X into `tar_dir` and
        store the ids in `self.duplicates_ids`.
        """
        url = "https://zenodo.org/record/4288677/files/DupMols.dat"
        tar_path = os.path.join(tar_dir, "DupMols.dat")
        checksum = "5d886ccac38877c8cb26c07704dd1034"

        download_and_check(url, tar_path, checksum)

        # fetch duplicates ids: one id per line, the trailing 4 characters of
        # each line are cut off
        with open(tar_path, "r") as dup_file:
            self.duplicates_ids = [line.rstrip("\n")[:-4] for line in dup_file]

    def _download_data(self, tar_dir: str, ignore_extracted: bool = True) -> List[str]:
        """
        Download the QM7-X archives into `tar_dir` and extract them.

        Args:
            tar_dir: directory the archives and extracted files are stored in
            ignore_extracted: if True, archives whose extracted `.hdf5` file is
                already present in `tar_dir` are not downloaded

        Returns:
            paths of the extracted hdf5 files
        """
        file_ids = ["1000", "2000", "3000", "4000", "5000", "6000", "7000", "8000"]

        # file fingerprints to check integrity
        checksums = [
            "b50c6a5d0a4493c274368cf22285503e",
            "4418a813daf5e0d44aa5a26544249ee6",
            "f7b5aac39a745f11436047c12d1eb24e",
            "26819601705ef8c14080fa7fc69decd4",
            "85ac444596b87812aaa9e48d203d0b70",
            "787fc4a9036af0e67c034a30ad854c07",
            "5ecce00a188410d06b747cb683d8d347",
            "c893ae88b8f5c32541c3f024fc1daa45",
        ]

        logging.info("Downloading QM7-X data files ...")

        # download files
        for file_id, checksum in zip(file_ids, checksums):
            if ignore_extracted and os.path.exists(
                os.path.join(tar_dir, f"{file_id}.hdf5")
            ):
                logging.info(
                    f"File {file_id}.xz exists in extracted version {file_id}.hdf5 already, skipping download."
                )
                continue

            url = f"https://zenodo.org/record/4288677/files/{file_id}.xz"
            tar_path = os.path.join(tar_dir, f"{file_id}.xz")
            download_and_check(url, tar_path, checksum)

        # extract the compressed files
        extracted = []
        for file_id in file_ids:
            xz_path = os.path.join(tar_dir, f"{file_id}.xz")
            hd_path = os.path.join(tar_dir, f"{file_id}.hdf5")
            extract_xz(xz_path, hd_path)
            extracted.append(hd_path)

        return extracted

    def _parse_data(self, files: List[str], dataset: BaseAtomsData):
        """
        Parse the downloaded data files and add them to the dataset.

        Systems are written to `dataset` file by file; the hierarchical ids of
        all systems are accumulated in the dataset metadata under `groups_ids`.

        Args:
            files: paths of the extracted hdf5 files
            dataset: dataset the parsed systems are added to

        Raises:
            RuntimeError: if duplicates should be removed but the duplicate ids
                have not been loaded.
        """
        if self.remove_duplicates:
            if self.duplicates_ids is None:
                raise RuntimeError(
                    "Duplicate ids are not loaded. Call `_download_duplicates_ids` first."
                )
            # set lookup instead of scanning a list once per conformation
            duplicates = frozenset(self.duplicates_ids)
        else:
            duplicates = frozenset()

        number_pattern = re.compile(r"\d+")

        # parse the data files
        for file in files:
            file_name = os.path.basename(file)
            logging.info(f"Parsing {file_name} ...")

            atoms_list = []
            property_list = []
            groups_ids = {
                "smiles_id": [],
                "stereo_iso_id": [],
                "conform_id": [],
                "step_id": [],
            }

            with h5py.File(file, "r") as mol_dict:
                for _mol_id, mol in tqdm(mol_dict.items()):
                    for conf_id, conf in mol.items():
                        # exclude equilibrium duplicates: compare the id without
                        # its last "-"-separated component
                        trunc_id = conf_id.rsplit("-", 1)[0]
                        if trunc_id in duplicates:
                            continue

                        ats = Atoms(positions=conf["atXYZ"], numbers=conf["atNUM"])
                        properties = {
                            key: np.array(
                                conf[QM7X.property_dataset_keys[key]], dtype=np.float64
                            )
                            for key in QM7X.property_unit_dict
                        }

                        # get the hierarchical ids for each system
                        if "opt" in conf_id:
                            # replace the 'opt' key with id 'd0'
                            conf_id = conf_id[:-3] + "d0"
                        ids = map(int, number_pattern.findall(conf_id))

                        atoms_list.append(ats)
                        property_list.append(properties)

                        # save the hierarchical ids for each system in same
                        # order as the systems
                        for group, value in zip(groups_ids, ids):
                            groups_ids[group].append(value)

            # add the data to the dataset
            logging.info(f"Write parsed data from {file_name} to db ...")
            dataset.add_systems(property_list=property_list, atoms_list=atoms_list)

            # add the hierarchical ids to the metadata
            md = dataset.metadata
            n_new = len(atoms_list)
            if "groups_ids" in md:
                stored = md["groups_ids"]
                merged = {key: stored[key] + ids for key, ids in groups_ids.items()}

                # add the ids as in the database of the new added systems;
                # an empty id list means nothing has been stored so far
                last_id = stored["id"][-1] if stored["id"] else 0
                merged["id"] = stored["id"] + list(
                    range(last_id + 1, last_id + n_new + 1)
                )
            else:
                merged = dict(groups_ids)
                merged["id"] = list(range(1, n_new + 1))

            dataset.update_metadata(groups_ids=merged)

        logging.info("Done.")

    def prepare_data(self):
        """
        Prepare data for the pytorch lightning data module.

        If no dataset exists at `self.datapath`, the raw data is downloaded,
        extracted and parsed into a newly created dataset. Raw files are kept
        in `self.raw_data_path` if it is set; otherwise a temporary directory
        is used and removed afterwards, also when preparation fails.
        """
        if os.path.exists(self.datapath):
            return

        use_tmp_dir = not self.raw_data_path
        if use_tmp_dir:
            tar_dir = tempfile.mkdtemp("qm7x")
        else:
            tar_dir = self.raw_data_path
            os.makedirs(tar_dir, exist_ok=True)

        try:
            # NOTE(review): the dataset is created before the download starts,
            # so a failed run leaves `self.datapath` behind and a later call
            # skips preparation - confirm whether that is intended.
            atomrefs = {
                QM7X.energy: [QM7X.EPBE0_atom.get(z, 0.0) for z in range(18)]
            }

            dataset = create_dataset(
                datapath=self.datapath,
                format=self.format,
                distance_unit="Ang",
                property_unit_dict=QM7X.property_unit_dict,
                atomrefs=atomrefs,
            )

            hd_files = self._download_data(tar_dir)
            if self.remove_duplicates:
                self._download_duplicates_ids(tar_dir)

            self._parse_data(hd_files, dataset)
        finally:
            if use_tmp_dir:
                # ignore_errors: never mask the original exception
                shutil.rmtree(tar_dir, ignore_errors=True)

    def setup(self, stage=None):
        """
        Load the dataset, optionally restrict it to (non-)equilibrium
        structures, and create the train/validation/test subsets.
        """
        if self.data_workdir is None:
            datapath = self.datapath
        else:
            datapath = self._copy_to_workdir()

        # (re)load datasets
        if self.dataset is None:
            self.dataset = load_dataset(
                datapath,
                self.format,
                property_units=self.property_units,
                distance_unit=self.distance_unit,
                load_properties=self.load_properties,
            )

            # use subset of equilibrium / non-equilibrium structures; done only
            # right after loading so a repeated call does not filter twice
            if self.only_equilibrium or self.only_non_equilibrium:
                step_ids = self.dataset.metadata["groups_ids"]["step_id"]
                if len(step_ids) != len(self.dataset):
                    raise ValueError(
                        "The dataset size does not match the size of step ids arrays in meta data."
                    )

                # step id 0 marks the structures that had the 'opt' key in the
                # raw data; `only_equilibrium` takes precedence if both are set
                if self.only_equilibrium:
                    eq_indices = [i for i, s in enumerate(step_ids) if s == 0]
                else:
                    eq_indices = [i for i, s in enumerate(step_ids) if s != 0]
                self.dataset = self.dataset.subset(eq_indices)

            # load and generate partitions if needed
            if self.train_idx is None:
                self._load_partitions()

            # partition dataset
            self._train_dataset = self.dataset.subset(self.train_idx)
            self._val_dataset = self.dataset.subset(self.val_idx)
            self._test_dataset = self.dataset.subset(self.test_idx)

        self._setup_transforms()