Source code for data.stats

from typing import Dict, Tuple

import torch
from tqdm import tqdm

import schnetpack.properties as properties
from schnetpack.data import AtomsLoader

# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["calculate_stats", "estimate_atomrefs"]


def calculate_stats(
    dataloader: AtomsLoader,
    divide_by_atoms: Dict[str, bool],
    atomref: Dict[str, torch.Tensor] = None,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Accumulate mean and standard deviation of properties over a data loader.

    Uses the incremental Welford algorithm described in [h1]_: statistics are
    computed per batch and merged into running accumulators, so only one batch
    has to be held in memory at a time.

    References:
    -----------
    .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance

    Args:
        dataloader: data loader yielding batches (dicts of tensors).
        divide_by_atoms: dict from property name to bool: If True, divide
            property by number of atoms before calculating statistics. The
            keys determine which properties are evaluated.
        atomref: reference values for single atoms to be removed before
            calculating stats. Properties without an entry are used as-is.

    Returns:
        Dict from property name to a tuple ``(mean, stddev)`` of float64
        scalar tensors over all samples. The standard deviation is the
        population value (normalised by the sample count). If the loader
        yields no batches, the means are zero and the standard deviations
        are NaN.
    """
    property_names = list(divide_by_atoms.keys())
    # 1.0 for properties that are normalised per atom, 0.0 otherwise.
    norm_mask = torch.tensor(
        [float(divide_by_atoms[p]) for p in property_names], dtype=torch.float64
    )

    count = 0
    mean = torch.zeros_like(norm_mask)
    M2 = torch.zeros_like(norm_mask)

    for props in tqdm(dataloader, "calculating statistics"):
        sample_values = []
        for p in property_names:
            # Work in the accumulator precision (float64) so that the
            # per-batch mean / M2 are not rounded to the property's dtype.
            val = props[p][None, :].to(norm_mask.dtype)
            if atomref is not None and p in atomref:
                # Sum the single-atom reference values of every molecule.
                ar = atomref[p][props[properties.Z]]
                idx_m = props[properties.idx_m]
                tmp = torch.zeros((idx_m[-1] + 1,), dtype=ar.dtype, device=ar.device)
                v0 = tmp.index_add(0, idx_m, ar)
                # Out-of-place on purpose: `val` may share storage with the
                # batch tensor, which must not be modified.
                val = val - v0
            sample_values.append(val)
        sample_values = torch.cat(sample_values, dim=0)

        batch_size = sample_values.shape[1]
        new_count = count + batch_size

        # Divisor is n_atoms where the mask is 1 and exactly 1 elsewhere.
        norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + (
            1 - norm_mask[:, None]
        )
        sample_values = sample_values / norm

        sample_mean = torch.mean(sample_values, dim=1)
        sample_m2 = torch.sum((sample_values - sample_mean[:, None]) ** 2, dim=1)

        # Merge the batch statistics into the running statistics.
        delta = sample_mean - mean
        mean += delta * batch_size / new_count
        corr = batch_size * count / new_count
        M2 += sample_m2 + delta**2 * corr
        count = new_count

    stddev = torch.sqrt(M2 / count)
    stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)}
    return stats
def estimate_atomrefs(dataloader, is_extensive, z_max=100):
    """
    Uses linear regression to estimate the elementwise biases (atomrefs).

    For every sample the number of atoms of each atomic number is counted,
    and the resulting count matrix is fitted to the property values by
    linear least squares.

    Args:
        dataloader: data loader; its ``dataset`` attribute must support
            ``len()``, which sizes the internal buffers.
        is_extensive: dict from property name to bool. The keys determine
            which properties are fitted. For properties flagged False, the
            value is multiplied by the sample's number of atoms when it is
            collected and divided by the sample's total atom count again
            before the fit.
        z_max: size of the returned tensors; every atomic number in the data
            must be smaller than this value.

    Returns:
        Dict from property name to a tensor of shape ``(z_max,)`` holding the
        fitted contribution at the index of every atomic number that occurs
        in the data and zero everywhere else.
    """
    property_names = list(is_extensive.keys())
    n_data = len(dataloader.dataset)
    all_properties = {pname: torch.zeros(n_data) for pname in property_names}
    all_atom_types = torch.zeros((n_data, z_max))
    data_counter = 0

    # loop over all batches
    for batch in tqdm(dataloader, "estimating atomrefs"):
        idx_m = batch[properties.idx_m]
        atomic_numbers = batch[properties.Z]

        # one row of atom-type counts per molecule in the batch
        for i in torch.unique(idx_m):
            atom_types, atom_counts = torch.unique(
                atomic_numbers[idx_m == i], return_counts=True
            )
            for atom_type, atom_count in zip(atom_types, atom_counts):
                all_atom_types[data_counter, atom_type] = atom_count
            for pname in property_names:
                property_value = batch[pname][i]
                if not is_extensive[pname]:
                    # Out-of-place on purpose: `batch[pname][i]` is a view
                    # into the batch tensor, which must not be modified.
                    property_value = property_value * batch[properties.n_atoms][i]
                all_properties[pname][data_counter] = property_value
            data_counter += 1

    # Only rows that were actually filled take part in the fit; rows left at
    # zero (loader yielded fewer samples than the dataset holds) would turn
    # into 0 / 0 = NaN targets below.
    atom_type_counts = all_atom_types[:data_counter]
    existing_atom_types = torch.where(atom_type_counts.sum(dim=0) != 0)[0]

    # Design matrix (n_samples, n_existing_types). It is deliberately kept
    # 2-D even when only one atom type or one sample is present.
    X = atom_type_counts[:, existing_atom_types]
    # The pseudoinverse equals inv(X.T @ X) @ X.T for full column rank and
    # still yields the minimum-norm least-squares solution when the
    # compositions are linearly dependent. It is shared by all properties.
    X_pinv = torch.linalg.pinv(X)
    atoms_per_sample = X.sum(dim=1)

    elementwise_contributions = {}
    for pname in property_names:
        targets = all_properties[pname][:data_counter]
        if not is_extensive[pname]:
            targets = targets / atoms_per_sample
        weights = X_pinv @ targets

        # scatter the fitted weights to their atomic-number positions
        contributions = torch.zeros((z_max))
        contributions[existing_atom_types] = weights
        elementwise_contributions[pname] = contributions

    return elementwise_contributions